osbm committed on
Commit 3953219 · 1 Parent(s): e42541c

Upload 9 files
prostate/data.py ADDED
@@ -0,0 +1,229 @@
# create dataloaders from csv files

## ---------- imports ----------
import os
import torch
import shutil
import numpy as np
import pandas as pd
from typing import Union
from monai.utils import first
from functools import partial
from collections import namedtuple
from monai.data import DataLoader as MonaiDataLoader

from . import transforms
from .utils import num_workers


def import_dataset(config: dict):
    if config.data.dataset_type == 'persistent':
        from monai.data import PersistentDataset
        if os.path.exists(config.data.cache_dir):
            shutil.rmtree(config.data.cache_dir)  # remove the previous cached dataset
        os.makedirs(config.data.cache_dir, exist_ok=True)
        Dataset = partial(PersistentDataset, cache_dir=config.data.cache_dir)
    elif config.data.dataset_type == 'cache':
        from monai.data import CacheDataset
        raise NotImplementedError('CacheDataset not yet implemented')
    else:
        from monai.data import Dataset
    return Dataset


class DataLoader(MonaiDataLoader):
    "Overwrite the monai DataLoader for enhanced viewing capabilities"

    def show_batch(self,
                   image_key: str = 'image',
                   label_key: str = 'label',
                   image_transform=lambda x: x.squeeze().transpose(0, 2).flip(-2),
                   label_transform=lambda x: x.squeeze().transpose(0, 2).flip(-2)):
        """Args:
            image_key: dict key name for the image to view
            label_key: dict key name for the corresponding label. Can be a tensor or str
            image_transform: transform applied to the input before it is passed to the
                viewer, to ensure the image has ndim == 3 and is oriented correctly
            label_transform: transform applied to the labels before they are passed to
                the viewer, to ensure segmentation masks have the same shape and
                orientation as the images. Should be the identity function if labels are str.
        """
        from .viewer import ListViewer

        batch = first(self)
        image = torch.unbind(batch[image_key], 0)
        label = torch.unbind(batch[label_key], 0)

        ListViewer([image_transform(im) for im in image],
                   [label_transform(im) for im in label]).show()
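
# Example (illustrative, not part of the uploaded file): previewing a training batch
# in a notebook. `config` is assumed to come from `utils.load_config`; the key names
# follow `segmentation_dataloaders` below.
#
#   train_loader = segmentation_dataloaders(config, train=True, valid=False, test=False)
#   train_loader.show_batch(image_key='image', label_key='label')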

# TODO
## Work with 3 dataloaders

def segmentation_dataloaders(config: dict,
                             train: bool = None,
                             valid: bool = None,
                             test: bool = None,
                             ):
    """Create segmentation dataloaders
    Args:
        config: config file
        train: whether to return a train DataLoader
        valid: whether to return a valid DataLoader
        test: whether to return a test DataLoader
    Args from config:
        data_dir: base directory for the data
        csv_name: path to csv file containing filenames and paths
        image_cols: columns in csv containing paths to images
        label_cols: columns in csv containing paths to label files
        dataset_type: PersistentDataset, CacheDataset and Dataset are supported
        cache_dir: cache directory to be used by PersistentDataset
        batch_size: batch size for training. Valid and test are always 1
        debug: run with a reduced number of images
    Returns:
        list of:
            train_loader: DataLoader (optional, if train==True)
            valid_loader: DataLoader (optional, if valid==True)
            test_loader: DataLoader (optional, if test==True)
    """

    ## parse needed arguments from config
    if train is None: train = config.data.train
    if valid is None: valid = config.data.valid
    if test is None: test = config.data.test

    data_dir = config.data.data_dir
    train_csv = config.data.train_csv
    valid_csv = config.data.valid_csv
    test_csv = config.data.test_csv
    image_cols = config.data.image_cols
    label_cols = config.data.label_cols
    dataset_type = config.data.dataset_type
    cache_dir = config.data.cache_dir
    batch_size = config.data.batch_size
    debug = config.debug

    ## ---------- data dicts ----------

    # First a global data dict, containing only the file paths from image_cols and
    # label_cols, is created. For this, the dataframe is reduced to only the relevant
    # columns. Then the rows are iterated, converting each row into an individual
    # dict, as expected by monai.

    if not isinstance(image_cols, (tuple, list)): image_cols = [image_cols]
    if not isinstance(label_cols, (tuple, list)): label_cols = [label_cols]

    train_df = pd.read_csv(train_csv)
    valid_df = pd.read_csv(valid_csv)
    test_df = pd.read_csv(test_csv)
    if debug:
        train_df = train_df.sample(25)
        valid_df = valid_df.sample(5)

    train_df['split'] = 'train'
    valid_df['split'] = 'valid'
    test_df['split'] = 'test'
    whole_df = []
    if train: whole_df += [train_df]
    if valid: whole_df += [valid_df]
    if test: whole_df += [test_df]
    df = pd.concat(whole_df)
    cols = image_cols + label_cols
    for col in cols:
        # create absolute file names from the relative paths in df and data_dir
        df[col] = [os.path.join(data_dir, fn) for fn in df[col]]
        if not os.path.exists(list(df[col])[0]):
            raise FileNotFoundError(list(df[col])[0])

    data_dict = [dict(row[1]) for row in df[cols].iterrows()]
    # data_dict is not the most precise name, list_of_data_dicts would be more
    # accurate, but also longer. The data_dict looks like this:
    # [
    #  {'image_col_1': 'data_dir/path/to/image1',
    #   'image_col_2': 'data_dir/path/to/image2',
    #   'label_col_1': 'data_dir/path/to/label1'},
    #  {'image_col_1': 'data_dir/path/to/image1',
    #   'image_col_2': 'data_dir/path/to/image2',
    #   'label_col_1': 'data_dir/path/to/label1'},
    #  ...]
    # File names should now be absolute or relative to the working directory.

    # now we create separate data dicts for train, valid and test data respectively
    assert train or test or valid, 'No dataset type is specified (train/valid or test)'

    if test:
        test_files = list(map(data_dict.__getitem__, *np.where(df.split == 'test')))

    if valid:
        val_files = list(map(data_dict.__getitem__, *np.where(df.split == 'valid')))

    if train:
        train_files = list(map(data_dict.__getitem__, *np.where(df.split == 'train')))

    # transforms are specified in transforms.py and are just loaded here
    if train: train_transforms = transforms.get_train_transforms(config)
    if valid: val_transforms = transforms.get_val_transforms(config)
    if test: test_transforms = transforms.get_test_transforms(config)


    ## ---------- construct dataloaders ----------
    Dataset = import_dataset(config)
    data_loaders = []
    if train:
        train_ds = Dataset(data=train_files, transform=train_transforms)
        train_loader = DataLoader(
            train_ds,
            batch_size=batch_size,
            num_workers=num_workers(),
            shuffle=True
        )
        data_loaders.append(train_loader)

    if valid:
        val_ds = Dataset(data=val_files, transform=val_transforms)
        val_loader = DataLoader(
            val_ds,
            batch_size=1,
            num_workers=num_workers(),
            shuffle=False
        )
        data_loaders.append(val_loader)

    if test:
        test_ds = Dataset(data=test_files, transform=test_transforms)
        test_loader = DataLoader(
            test_ds,
            batch_size=1,
            num_workers=num_workers(),
            shuffle=False
        )
        data_loaders.append(test_loader)

    # If only one dataloader is constructed, return it directly; otherwise return a
    # named tuple of dataloaders, so it is clear which DataLoader is train/valid/test.

    if len(data_loaders) == 1:
        return data_loaders[0]
    else:
        DataLoaders = namedtuple(
            'DataLoaders',
            # create str with specification of the loader types; if train and test are
            # true but valid is false, the string will be 'train test'
            ' '.join(
                [
                    'train' if train else '',
                    'valid' if valid else '',
                    'test' if test else ''
                ]
            ).strip()
        )
        return DataLoaders(*data_loaders)
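
`segmentation_dataloaders` reads everything from one config object. As a rough sketch of the fields it touches (a hedged reconstruction from the attribute accesses above, not the repository's actual config; the CSV file names and column names are placeholders):

import munch

config = munch.munchify({
    'debug': False,
    'data': {
        'data_dir': 'data',            # base directory, prepended to all paths
        'train_csv': 'train.csv',      # placeholder file names
        'valid_csv': 'valid.csv',
        'test_csv': 'test.csv',
        'image_cols': ['image'],       # placeholder column names
        'label_cols': ['label'],
        'dataset_type': 'persistent',  # 'persistent' | 'cache' | anything else -> Dataset
        'cache_dir': '.cache',
        'batch_size': 4,
        'train': True, 'valid': True, 'test': False,
    },
})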
prostate/loss.py ADDED
@@ -0,0 +1,23 @@
import monai
from .utils import load_config


def get_loss(config: dict):
    """Create a loss function of `type` with specific keyword arguments from config.
    Example:

        config.loss
        >>> {'DiceCELoss': {'include_background': False, 'softmax': True, 'to_onehot_y': True}}

        get_loss(config)
        >>> DiceCELoss(
        >>>   (dice): DiceLoss()
        >>>   (cross_entropy): CrossEntropyLoss()
        >>> )

    """
    loss_type = list(config.loss.keys())[0]
    loss_config = config.loss[loss_type]
    loss_fun = getattr(monai.losses, loss_type)
    loss = loss_fun(**loss_config)
    return loss
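
A quick sanity check of `get_loss`; `munch` provides the attribute access the function relies on (a minimal sketch, mirroring the docstring example):

import munch
from prostate.loss import get_loss

config = munch.munchify({
    'loss': {'DiceCELoss': {'include_background': False, 'softmax': True, 'to_onehot_y': True}}
})
loss_fn = get_loss(config)  # resolves to monai.losses.DiceCELoss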
prostate/model.py ADDED
@@ -0,0 +1,16 @@
# create a standard UNet

from monai.networks.nets import UNet

def get_model(config: dict):
    return UNet(
        spatial_dims=config.ndim,
        in_channels=len(config.data.image_cols),
        out_channels=config.model.out_channels,
        channels=config.model.channels,
        strides=config.model.strides,
        num_res_units=config.model.num_res_units,
        act=config.model.act,
        norm=config.model.norm,
        dropout=config.model.dropout,
    )
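
For reference, a config block that satisfies every attribute `get_model` reads; the values are illustrative defaults for a 3D UNet, not the repository's settings:

import munch
from prostate.model import get_model

config = munch.munchify({
    'ndim': 3,                          # spatial dimensions
    'data': {'image_cols': ['image']},  # one input channel per image column
    'model': {
        'out_channels': 3,
        'channels': (16, 32, 64, 128, 256),
        'strides': (2, 2, 2, 2),
        'num_res_units': 2,
        'act': 'PRELU',
        'norm': 'INSTANCE',
        'dropout': 0.1,
    },
})
unet = get_model(config)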
prostate/optimizer.py ADDED
@@ -0,0 +1,37 @@
import torch
import monai
from .utils import load_config


def get_optimizer(model: torch.nn.Module,
                  config: dict):
    """Create an optimizer of `type` with specific keyword arguments from config.
    Example:

        config.optimizer
        >>> {'Novograd': {'lr': 0.001, 'weight_decay': 0.01}}

        get_optimizer(model, config)
        >>> Novograd (
        >>>   Parameter Group 0
        >>>     amsgrad: False
        >>>     betas: (0.9, 0.999)
        >>>     eps: 1e-08
        >>>     grad_averaging: False
        >>>     lr: 0.0001
        >>>     weight_decay: 0.001
        >>> )

    """
    optimizer_type = list(config.optimizer.keys())[0]
    opt_config = config.optimizer[optimizer_type]
    if hasattr(torch.optim, optimizer_type):
        optimizer_fun = getattr(torch.optim, optimizer_type)
    elif hasattr(monai.optimizers, optimizer_type):
        optimizer_fun = getattr(monai.optimizers, optimizer_type)
    else:
        raise ValueError(f'Optimizer {optimizer_type} not found')
    optimizer = optimizer_fun(model.parameters(), **opt_config)
    return optimizer
@@ -0,0 +1,184 @@
 
import os
import io
import cv2
import tqdm
import torch
import imageio
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

class ReportGenerator():
    "Generate a markdown document summarizing the training"

    def __init__(self, run_id, out_dir=None, log_dir=None):

        self.run_id, self.out_dir, self.log_dir = run_id, out_dir, log_dir

        if log_dir:
            self.train_logs = pd.read_csv(os.path.join(log_dir, 'train_logs.csv'))
            self.metric_logs = pd.read_csv(os.path.join(log_dir, 'metric_logs.csv'))
        if out_dir:
            self.dice = pd.read_csv(os.path.join(out_dir, 'MeanDice_raw.csv'))
            self.hausdorf = pd.read_csv(os.path.join(out_dir, 'HausdorffDistance_raw.csv'))
            self.surface = pd.read_csv(os.path.join(out_dir, 'SurfaceDistance_raw.csv'))

            self.mean_metrics = pd.DataFrame(
                {"mean_dice": [round(np.mean(self.dice[col]), 3) for col in self.dice if col.startswith('class')],
                 "mean_hausdorf": [round(np.mean(self.hausdorf[col]), 3) for col in self.hausdorf if col.startswith('class')],
                 "mean_surface": [round(np.mean(self.surface[col]), 3) for col in self.surface if col.startswith('class')]
                 }).transpose()

    def generate_report(self, loss_plot=True, metric_plot=True, boxplots=True, animation=True):
        fn = os.path.join(self.run_id, 'report', 'SegmentationReport.md')
        os.makedirs(os.path.join(self.run_id, 'report'), exist_ok=True)
        with open(fn, 'w+') as f:
            f.write('# Segmentation Report\n\n')

        if loss_plot:
            fig = self.plot_loss(self.train_logs, self.metric_logs)
            plt.savefig(os.path.join(self.run_id, 'report', 'loss_and_lr.png'), dpi=150)

            with open(fn, 'a') as f:
                f.write('## Loss, LR-Schedule and Key Metric\n')
                f.write('![Loss, LR-Schedule and Key Metric](loss_and_lr.png)\n\n')

        if metric_plot:
            fig = plt.figure("metrics", (18, 6))

            ax = plt.subplot(1, 3, 1)
            plt.ylim([0, 1])
            plt.title("Mean Dice")
            plt.xlabel("epoch")
            plt.plot(self.metric_logs.index, self.metric_logs.MeanDice)

            ax = plt.subplot(1, 3, 2)
            plt.title("Mean Hausdorff Distance")
            plt.xlabel("epoch")
            plt.plot(self.metric_logs.index, self.metric_logs.HausdorffDistance)

            ax = plt.subplot(1, 3, 3)
            plt.title("Mean Surface Distance")
            plt.xlabel("epoch")
            plt.plot(self.metric_logs.index, self.metric_logs.SurfaceDistance)

            plt.savefig(os.path.join(self.run_id, 'report', 'metrics.png'), dpi=150)
            fig.clear()
            plt.close()

            with open(fn, 'a') as f:
                f.write('## Metrics\n')
                f.write('![metrics](metrics.png)\n\n')

        if boxplots:
            fig = plt.figure("boxplots", (18, 6))

            ax = plt.subplot(1, 3, 1)
            plt.title("Dice")
            plt.xlabel("class")
            plt.boxplot(self.dice[[col for col in self.dice if col.startswith('class')]])

            ax = plt.subplot(1, 3, 2)
            plt.title("Hausdorff Distance")
            plt.xlabel("class")
            plt.boxplot(self.hausdorf[[col for col in self.hausdorf if col.startswith('class')]])

            ax = plt.subplot(1, 3, 3)
            plt.title("Surface Distance")
            plt.xlabel("class")
            plt.boxplot(self.surface[[col for col in self.surface if col.startswith('class')]])

            plt.savefig(os.path.join(self.run_id, 'report', 'boxplots.png'), dpi=150)

            fig.clear()
            plt.close()

            with open(fn, 'a') as f:
                f.write("## Individual metrics\n\n")
                f.write(f"{self.mean_metrics.to_markdown()}\n\n")
                f.write("![boxplot](boxplots.png)\n\n")

        if animation:
            self.generate_gif()
            with open(fn, 'a') as f:
                f.write('## Visualization of progress\n')
                f.write('![progress](progress.gif)\n\n')

    def plot_loss(self, train_logs, metric_logs):
        iteration = train_logs.iteration / sum(train_logs.epoch == 1)
        fig = plt.figure("loss and lr", (12, 6))

        y_max = max(metric_logs.eval_loss) + 0.5
        if y_max > 3: y_max = 3

        ax = plt.subplot(1, 2, 1)
        plt.ylim([0, y_max])
        plt.title("Epoch Average Loss")
        plt.xlabel("epoch")
        plt.plot(iteration, train_logs.loss)
        plt.plot(metric_logs.index, metric_logs.eval_loss)

        ax = plt.subplot(1, 2, 2)
        ax.set_yscale('log')
        plt.title("LR Schedule")
        plt.xlabel("epoch")
        plt.plot(iteration, train_logs.lr)
        return fig

    def get_arr_from_fig(self, fig, dpi=180):
        buf = io.BytesIO()
        fig.savefig(buf, format="png", dpi=dpi)
        buf.seek(0)
        img_arr = np.frombuffer(buf.getvalue(), dtype=np.uint8)
        buf.close()
        img = cv2.imdecode(img_arr, 1)
        # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        return img

    def get_slices(self, im, slices):
        ims = torch.unbind(im[:, :, slices], -1)  # extract n slices
        ims = [i.transpose(0, 1).flip(0) for i in ims]  # rotate slices by 90 degrees
        if len(slices) > 4 and len(slices) % 2 == 0:
            n = len(slices) // 2
            ims1 = torch.cat(ims[0:n], 1)
            ims2 = torch.cat(ims[n:], 1)
            return torch.cat([ims1, ims2], 0)
        else:
            return torch.cat(ims, 1)  # create tile

    def plot_images(self, fns, slices, cmap='Greys_r', figsize=15, **kwargs):
        ims = [torch.load(os.path.join(self.out_dir, 'preds', fn)).cpu().argmax(0) for fn in fns]
        ims = [self.get_slices(im, slices) for im in ims]
        ims = torch.cat(ims, 0)
        plt.figure(figsize=(figsize, figsize))
        plt.imshow(ims, cmap=cmap, **kwargs)
        plt.axis('off')

    def load_segmentation_image(self, fn):
        im = torch.load(fn).cpu().unsqueeze(0)
        im = torch.nn.functional.interpolate(im, (224, 224, 112))
        im = im.argmax(1).squeeze()
        im = self.get_slices(im, slices=(40, 48, 56, 74, 82, 90))
        im = im / im.max() * 255
        return im

    def generate_gif(self):
        with imageio.get_writer(
                os.path.join(self.run_id, 'report', 'progress.gif'),
                mode='I',
                fps=max(self.train_logs.epoch) // 10) as writer:  # make the gif about 10 seconds long
            for epoch in tqdm.tqdm(list(self.train_logs.epoch.unique())):
                seg_fn = os.path.join(self.out_dir, 'preds', f"pred_epoch_{epoch}.pt")
                if os.path.exists(seg_fn): im = self.load_segmentation_image(seg_fn)

                plt_train_logs = self.train_logs[self.train_logs.epoch <= epoch]
                loss_plt = self.plot_loss(plt_train_logs, self.metric_logs[:epoch])
                loss_fig = self.get_arr_from_fig(loss_plt)[:, :, 0]

                new_shape = im.shape[1], int(loss_fig.shape[0] / loss_fig.shape[1] * im.shape[1])
                loss_fig = cv2.resize(loss_fig, (im.shape[1], im.shape[0]))

                images = torch.cat([im, torch.tensor(loss_fig)], 0).numpy().astype(np.uint8)
                writer.append_data(images)

                loss_plt.clear()
                plt.close()
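
A sketch of how the generator might be invoked after a training run; the directory layout follows the CSV names read in `__init__` and the files written by the training handlers (adjust to the actual run):

from prostate.report import ReportGenerator

reporter = ReportGenerator(
    run_id='my_run',          # report is written to my_run/report/
    out_dir='my_run/output',  # expects MeanDice_raw.csv etc. from MetricsSaver
    log_dir='my_run/logs',    # expects train_logs.csv and metric_logs.csv
)
reporter.generate_report(animation=False)  # animation needs the saved pred_epoch_*.pt files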
prostate/train.py ADDED
@@ -0,0 +1,482 @@
import os
import yaml
import munch
import torch
import ignite
import monai
import shutil
import pandas as pd

from typing import Union, List, Callable
from ignite.contrib.handlers.tqdm_logger import ProgressBar
from monai.handlers import (
    CheckpointSaver,
    StatsHandler,
    TensorBoardStatsHandler,
    TensorBoardImageHandler,
    ValidationHandler,
    from_engine,
    MeanDice,
    EarlyStopHandler,
    MetricLogger,
    MetricsSaver
)

from .data import segmentation_dataloaders
from .model import get_model
from .optimizer import get_optimizer
from .loss import get_loss
from .transforms import get_val_post_transforms
from .utils import USE_AMP

def loss_logger(engine):
    "write loss and lr of each iteration/epoch to file"
    iteration = engine.state.iteration
    epoch = engine.state.epoch
    loss = [o['loss'] for o in engine.state.output]
    loss = sum(loss) / len(loss)
    lr = engine.optimizer.param_groups[0]['lr']
    log_file = os.path.join(engine.config.log_dir, 'train_logs.csv')
    if not os.path.exists(log_file):
        with open(log_file, 'w+') as f:
            f.write('iteration,epoch,loss,lr\n')
    with open(log_file, 'a') as f:
        f.write(f'{iteration},{epoch},{loss},{lr}\n')

def metric_logger(engine):
    "write `metrics` after each epoch to file"
    if engine.state.epoch > 1:  # only the key metric is calculated in the 1st epoch, needs fix
        metric_names = [k for k in engine.state.metrics.keys()]
        metrics = [str(engine.state.metrics[mn]) for mn in metric_names]
        log_file = os.path.join(engine.config.log_dir, 'metric_logs.csv')
        if not os.path.exists(log_file):
            with open(log_file, 'w+') as f:
                f.write(','.join(metric_names) + '\n')
        with open(log_file, 'a') as f:
            f.write(','.join(metrics) + '\n')

def pred_logger(engine):
    "save `pred` each time the metric improves"
    epoch = engine.state.epoch
    root = os.path.join(engine.config.out_dir, 'preds')
    if not os.path.exists(root):
        os.makedirs(root)
        torch.save(
            engine.state.output[0]['label'],
            os.path.join(root, 'label.pt')
        )
        torch.save(
            engine.state.output[0]['image'],
            os.path.join(root, 'image.pt')
        )

    if epoch == engine.state.best_metric_epoch:
        torch.save(
            engine.state.output[0]['pred'],
            os.path.join(root, f'pred_epoch_{epoch}.pt')
        )


def get_val_handlers(
        network: torch.nn.Module,
        config: dict
) -> list:
    """Create default handlers for model validation
    Args:
        network:
            nn.Module subclass, the model to train

    Returns:
        a list of default handlers for validation: [
            StatsHandler:
                ???
            TensorBoardStatsHandler:
                Save loss from validation to `config.log_dir`, allow logging with TensorBoard
            CheckpointSaver:
                Save best model to `config.model_dir`
        ]
    """

    val_handlers = [
        StatsHandler(
            tag_name="metric_logger",
            epoch_print_logger=metric_logger,
            output_transform=lambda x: None
        ),
        StatsHandler(
            tag_name="pred_logger",
            epoch_print_logger=pred_logger,
            output_transform=lambda x: None
        ),
        TensorBoardStatsHandler(
            log_dir=config.log_dir,
            # tag_name="val_mean_dice",
            output_transform=lambda x: None
        ),
        TensorBoardImageHandler(
            log_dir=config.log_dir,
            batch_transform=from_engine(["image", "label"]),
            output_transform=from_engine(["pred"]),
        ),
        CheckpointSaver(
            save_dir=config.model_dir,
            save_dict={f"network_{config.run_id}": network},
            save_key_metric=True
        ),
    ]

    return val_handlers


def get_train_handlers(
        evaluator: monai.engines.SupervisedEvaluator,
        config: dict
) -> list:
    """Create default handlers for model training
    Args:
        evaluator: an engine of type `monai.engines.SupervisedEvaluator` for
            evaluation after every epoch

    Returns:
        a list of default handlers for training: [
            ValidationHandler:
                Allows model validation every epoch
            StatsHandler:
                ???
            TensorBoardStatsHandler:
                Save loss from training to `config.log_dir`, allow logging with TensorBoard
        ]
    """

    train_handlers = [
        ValidationHandler(
            validator=evaluator,
            interval=1,
            epoch_level=True
        ),
        StatsHandler(
            tag_name="train_loss",
            output_transform=from_engine(
                ["loss"],
                first=True
            )
        ),
        StatsHandler(
            tag_name='loss_logger',
            iteration_print_logger=loss_logger
        ),
        TensorBoardStatsHandler(
            log_dir=config.log_dir,
            tag_name="train_loss",
            output_transform=from_engine(
                ["loss"],
                first=True
            ),
        )
    ]

    return train_handlers

def get_evaluator(
        config: dict,
        device: torch.device,
        network: torch.nn.Module,
        val_data_loader: monai.data.dataloader.DataLoader,
        val_post_transforms: monai.transforms.compose.Compose,
        val_handlers: Union[Callable, List] = get_val_handlers
) -> monai.engines.SupervisedEvaluator:

    """Create a default evaluator for training of a segmentation model
    Args:
        device:
            torch.cuda.device for model and engine
        network:
            nn.Module subclass, the model to train
        val_data_loader:
            validation data loader, `monai.data.dataloader.DataLoader` subclass
        val_post_transforms:
            function to create transforms OR composed transforms
        val_handlers:
            function to create handlers OR list of handlers

    Returns:
        default evaluator for segmentation of type `monai.engines.SupervisedEvaluator`
    """

    if callable(val_handlers): val_handlers = val_handlers()

    evaluator = monai.engines.SupervisedEvaluator(
        device=device,
        val_data_loader=val_data_loader,
        network=network,
        inferer=monai.inferers.SlidingWindowInferer(
            roi_size=(96, 96, 96),
            sw_batch_size=4,
            overlap=0.5
        ),
        postprocessing=val_post_transforms,
        key_val_metric={
            "val_mean_dice": MeanDice(
                include_background=False,
                output_transform=from_engine(
                    ["pred", "label"]
                )
            )
        },
        val_handlers=val_handlers,
        # if there is no FP16 support in the GPU or the PyTorch version is < 1.6, AMP evaluation will not be enabled
        amp=USE_AMP,
    )
    evaluator.config = config
    return evaluator
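
# Example (illustrative, not part of the uploaded file): get_evaluator can also be
# used standalone for a validation-only run, e.g.:
#
#   net = get_model(config).to(config.device)
#   evaluator = get_evaluator(
#       config=config, device=config.device, network=net,
#       val_data_loader=val_loader,
#       val_post_transforms=get_val_post_transforms(config),
#       val_handlers=get_val_handlers(net, config),
#   )
#   evaluator.run()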

class SegmentationTrainer(monai.engines.SupervisedTrainer):
    "Default trainer for supervised segmentation tasks"
    def __init__(self,
                 config: dict,
                 progress_bar: bool = True,
                 early_stopping: bool = True,
                 metrics: list = ["MeanDice", "HausdorffDistance", "SurfaceDistance"],
                 save_latest_metrics: bool = True
                 ):
        self.config = config
        self._prepare_dirs()
        self.config.device = torch.device(self.config.device)

        train_loader, val_loader = segmentation_dataloaders(
            config=config,
            train=True,
            valid=True,
            test=False
        )
        network = get_model(config=config).to(config.device)
        optimizer = get_optimizer(
            network,
            config=config
        )
        loss_fn = get_loss(config=config)
        val_post_transforms = get_val_post_transforms(config=config)
        val_handlers = get_val_handlers(
            network,
            config=config
        )

        self.evaluator = get_evaluator(
            config=config,
            device=config.device,
            network=network,
            val_data_loader=val_loader,
            val_post_transforms=val_post_transforms,
            val_handlers=val_handlers,
        )
        train_handlers = get_train_handlers(
            self.evaluator,
            config=config
        )

        super().__init__(
            device=config.device,
            max_epochs=self.config.training.max_epochs,
            train_data_loader=train_loader,
            network=network,
            optimizer=optimizer,
            loss_function=loss_fn,
            inferer=monai.inferers.SimpleInferer(),
            train_handlers=train_handlers,
            amp=USE_AMP,
        )

        if early_stopping: self._add_early_stopping()
        if progress_bar: self._add_progress_bars()

        self.schedulers = []
        # add the requested metrics dynamically
        for m in metrics:
            getattr(monai.handlers, m)(
                include_background=False,
                reduction="mean",
                output_transform=from_engine(
                    ["pred", "label"]
                )
            ).attach(self.evaluator, m)

        self._add_metrics_logger()
        # add eval loss to metrics
        self._add_eval_loss()

        if save_latest_metrics: self._add_metrics_saver()


    def _prepare_dirs(self) -> None:
        # create run_id, copy config file for reproducibility
        os.makedirs(self.config.run_id, exist_ok=True)
        with open(
                os.path.join(
                    self.config.run_id,
                    'config.yaml'
                ), 'w+') as f:
            f.write(yaml.safe_dump(self.config))

        # delete old log_dir
        if os.path.exists(self.config.log_dir):
            shutil.rmtree(self.config.log_dir)

    def _add_early_stopping(self) -> None:
        early_stopping = EarlyStopHandler(
            patience=self.config.training.early_stopping_patience,
            min_delta=1e-4,
            score_function=lambda x: x.state.metrics[x.state.key_metric_name],
            trainer=self
        )
        self.evaluator.add_event_handler(
            ignite.engine.Events.COMPLETED,
            early_stopping
        )

    def _add_metrics_logger(self) -> None:
        self.metric_logger = MetricLogger(
            evaluator=self.evaluator
        )
        self.metric_logger.attach(self)

    def _add_progress_bars(self) -> None:
        trainer_pbar = ProgressBar()
        evaluator_pbar = ProgressBar(
            colour='green'
        )
        trainer_pbar.attach(
            self,
            output_transform=lambda output: {
                'loss': torch.tensor(
                    [x['loss'] for x in output]
                ).mean()
            }
        )
        evaluator_pbar.attach(self.evaluator)

    def _add_metrics_saver(self) -> None:
        metric_saver = MetricsSaver(
            save_dir=self.config.out_dir,
            metric_details='*',
            batch_transform=self._get_meta_dict,
            delimiter=','
        )
        metric_saver.attach(self.evaluator)

    def _add_eval_loss(self) -> None:
        # TODO improve by adding this to the val handlers
        eval_loss_handler = ignite.metrics.Loss(
            loss_fn=self.loss_function,
            output_transform=lambda output: (
                output[0]['pred'].unsqueeze(0),  # add batch dim
                output[0]['label'].argmax(0, keepdim=True).unsqueeze(0)  # reverse one-hot, add batch dim
            )
        )
        eval_loss_handler.attach(self.evaluator, 'eval_loss')

    def _get_meta_dict(self, batch) -> list:
        "Get dict of metadata from the engine. Needed as `batch_transform`"
        image_cols = self.config.data.image_cols
        image_name = image_cols[0] if isinstance(image_cols, list) else image_cols
        key = f'{image_name}_meta_dict'
        return [item[key] for item in batch]

    def load_checkpoint(self, checkpoint=None):
        if not checkpoint:
            # get the name of the last checkpoint
            checkpoint = os.path.join(
                self.config.model_dir,
                f"network_{self.config.run_id}_key_metric={self.evaluator.state.best_metric:.4f}.pt"
            )
        self.network.load_state_dict(
            torch.load(checkpoint)
        )

    def run(self, try_resume_from_checkpoint=True) -> None:
        """Run training; if `try_resume_from_checkpoint` is set, tries to
        load a previous checkpoint stored at `self.config.model_dir`
        """

        if try_resume_from_checkpoint:
            checkpoints = [
                os.path.join(
                    self.config.model_dir,
                    checkpoint_name
                ) for checkpoint_name in os.listdir(
                    self.config.model_dir
                ) if self.config.run_id in checkpoint_name
            ]
            try:
                checkpoint = sorted(checkpoints)[-1]
                self.load_checkpoint(checkpoint)
                print(f"resuming from previous checkpoint at {checkpoint}")
            except Exception:
                pass  # train from scratch

        # train the model
        super().run()

        # make metrics and losses more accessible
        self.loss = {
            "iter": [_iter for _iter, _ in self.metric_logger.loss],
            "loss": [_loss for _, _loss in self.metric_logger.loss],
            "epoch": [_iter // self.state.epoch_length for _iter, _ in self.metric_logger.loss]
        }

        self.metrics = {
            k: [item[1] for item in self.metric_logger.metrics[k]] for k in
            self.evaluator.state.metric_details.keys()
        }
        # pd.DataFrame(self.metrics).to_csv(f"{self.config.out_dir}/metric_logs.csv")
        # pd.DataFrame(self.loss).to_csv(f"{self.config.out_dir}/loss_logs.csv")

    def fit_one_cycle(self, try_resume_from_checkpoint=True) -> None:
        "Attach a one-cycle learning rate schedule, applied when training runs"
        assert "FitOneCycle" not in self.schedulers, "FitOneCycle already added"
        fit_one_cycle = monai.handlers.LrScheduleHandler(
            torch.optim.lr_scheduler.OneCycleLR(
                optimizer=self.optimizer,
                max_lr=self.optimizer.param_groups[0]['lr'],
                steps_per_epoch=self.state.epoch_length,
                epochs=self.state.max_epochs
            ),
            epoch_level=False,
            name="FitOneCycle"
        )
        fit_one_cycle.attach(self)
        self.schedulers += ["FitOneCycle"]

    def reduce_lr_on_plateau(self,
                             try_resume_from_checkpoint=True,
                             factor=0.1,
                             patience=10,
                             min_lr=1e-10,
                             verbose=True) -> None:
        "Reduce the learning rate by `factor` every `patience` epochs if the key_metric does not improve"
        assert "ReduceLROnPlateau" not in self.schedulers, "ReduceLROnPlateau already added"
        reduce_lr_on_plateau = monai.handlers.LrScheduleHandler(
            torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer=self.optimizer,
                factor=factor,
                patience=patience,
                min_lr=min_lr,
                verbose=verbose
            ),
            print_lr=True,
            name='ReduceLROnPlateau',
            epoch_level=True,
            step_transform=lambda engine: engine.state.metrics[engine.state.key_metric_name],
        )
        reduce_lr_on_plateau.attach(self.evaluator)
        self.schedulers += ["ReduceLROnPlateau"]

    def evaluate(self, checkpoint=None, dataloader=None):
        "Run evaluation with the best saved checkpoint"
        self.load_checkpoint(checkpoint)
        if dataloader:
            self.evaluator.set_data(dataloader)
            self.evaluator.state.epoch_length = len(dataloader)
        self.evaluator.run()
        print(f"metrics saved to {self.config.out_dir}")
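
The trainer bundles all of the above, so a full run reduces to a few lines (a sketch; it assumes a YAML file defining the keys the code reads, such as `run_id`, `device`, `data.*`, `model.*` and `training.*`):

from prostate.utils import load_config
from prostate.train import SegmentationTrainer

config = load_config('config.yaml')
trainer = SegmentationTrainer(config, progress_bar=True, early_stopping=True)
trainer.fit_one_cycle()  # optional: attach a one-cycle LR schedule before training
trainer.run()            # train, validating after every epoch
trainer.evaluate()       # re-run evaluation with the best checkpoint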
prostate/transforms.py ADDED
@@ -0,0 +1,363 @@
# create transforms for training, validation and test datasets

## TODO: Make transforms more dynamic by directly building them from config args
## Maybe like this:
##   TFM_NAME = config.transforms.keys()[0]
##   tfm_fun = getattr(monai.transforms, TFM_NAME)
##   tfms += [tfm_fun(keys=image_cols, **config.transforms[TFM_NAME], prob=prob, mode=mode)]


## ---------- imports ----------
import os
# only the base transforms are imported here, others are imported as needed
from monai.utils.enums import CommonKeys
from monai.transforms import (
    Activationsd,
    AsDiscreted,
    Compose,
    ConcatItemsd,
    KeepLargestConnectedComponentd,
    LoadImaged,
    EnsureChannelFirstd,
    EnsureTyped,
    SaveImaged,
    ScaleIntensityd,
    NormalizeIntensityd
)
# images should be interpolated with `bilinear` but masks with `nearest`

## ---------- base transforms ----------
# applied every time
def get_base_transforms(
        config: dict,
        minv: int = 0,
        maxv: int = 1
) -> list:

    tfms = []
    tfms += [LoadImaged(keys=config.data.image_cols + config.data.label_cols)]
    tfms += [EnsureChannelFirstd(keys=config.data.image_cols + config.data.label_cols)]
    if config.transforms.spacing:
        from monai.transforms import Spacingd
        tfms += [
            Spacingd(
                keys=config.data.image_cols + config.data.label_cols,
                pixdim=config.transforms.spacing,
                mode=config.transforms.mode
            )
        ]
    if config.transforms.orientation:
        from monai.transforms import Orientationd
        tfms += [
            Orientationd(
                keys=config.data.image_cols + config.data.label_cols,
                axcodes=config.transforms.orientation
            )
        ]
    tfms += [
        ScaleIntensityd(
            keys=config.data.image_cols,
            minv=minv,
            maxv=maxv
        )
    ]
    tfms += [NormalizeIntensityd(keys=config.data.image_cols)]
    return tfms

## ---------- train transforms ----------

def get_train_transforms(config: dict):
    tfms = get_base_transforms(config=config)

    # ---------- specific transforms for MRI ----------
    if 'rand_bias_field' in config.transforms.keys():
        from monai.transforms import RandBiasFieldd
        args = config.transforms.rand_bias_field
        tfms += [
            RandBiasFieldd(
                keys=config.data.image_cols,
                degree=args['degree'],
                coeff_range=args['coeff_range'],
                prob=config.transforms.prob
            )
        ]

    if 'rand_gaussian_smooth' in config.transforms.keys():
        from monai.transforms import RandGaussianSmoothd
        args = config.transforms.rand_gaussian_smooth
        tfms += [
            RandGaussianSmoothd(
                keys=config.data.image_cols,
                sigma_x=args['sigma_x'],
                sigma_y=args['sigma_y'],
                sigma_z=args['sigma_z'],
                prob=config.transforms.prob
            )
        ]

    # note: the config key is spelled 'rand_gibbs_nose' in the original schema
    if 'rand_gibbs_nose' in config.transforms.keys():
        from monai.transforms import RandGibbsNoised
        args = config.transforms.rand_gibbs_nose
        tfms += [
            RandGibbsNoised(
                keys=config.data.image_cols,
                alpha=args['alpha'],
                prob=config.transforms.prob
            )
        ]

    # ---------- affine transforms ----------

    if 'rand_affine' in config.transforms.keys():
        from monai.transforms import RandAffined
        args = config.transforms.rand_affine
        tfms += [
            RandAffined(
                keys=config.data.image_cols + config.data.label_cols,
                rotate_range=args['rotate_range'],
                shear_range=args['shear_range'],
                translate_range=args['translate_range'],
                mode=config.transforms.mode,
                prob=config.transforms.prob
            )
        ]

    if 'rand_rotate90' in config.transforms.keys():
        from monai.transforms import RandRotate90d
        args = config.transforms.rand_rotate90
        tfms += [
            RandRotate90d(
                keys=config.data.image_cols + config.data.label_cols,
                spatial_axes=args['spatial_axes'],
                prob=config.transforms.prob
            )
        ]

    if 'rand_rotate' in config.transforms.keys():
        from monai.transforms import RandRotated
        args = config.transforms.rand_rotate
        tfms += [
            RandRotated(
                keys=config.data.image_cols + config.data.label_cols,
                range_x=args['range_x'],
                range_y=args['range_y'],
                range_z=args['range_z'],
                mode=config.transforms.mode,
                prob=config.transforms.prob
            )
        ]

    if 'rand_elastic' in config.transforms.keys():
        if config['ndim'] == 3:
            from monai.transforms import Rand3DElasticd as RandElasticd
        elif config['ndim'] == 2:
            from monai.transforms import Rand2DElasticd as RandElasticd
        args = config.transforms.rand_elastic
        tfms += [
            RandElasticd(
                keys=config.data.image_cols + config.data.label_cols,
                sigma_range=args['sigma_range'],
                magnitude_range=args['magnitude_range'],
                rotate_range=args['rotate_range'],
                shear_range=args['shear_range'],
                translate_range=args['translate_range'],
                mode=config.transforms.mode,
                prob=config.transforms.prob
            )
        ]

    if 'rand_zoom' in config.transforms.keys():
        from monai.transforms import RandZoomd
        args = config.transforms.rand_zoom
        tfms += [
            RandZoomd(
                keys=config.data.image_cols + config.data.label_cols,
                min_zoom=args['min'],
                max_zoom=args['max'],
                mode=['area' if x == 'bilinear' else x for x in config.transforms.mode],
                prob=config.transforms.prob
            )
        ]

    # ---------- random cropping, very effective for large images ----------
    # RandCropByPosNegLabeld is not advisable for data with missing labels,
    # e.g., segmentation of carcinomas that are not present in all images;
    # in that case, fall back to RandSpatialCropSamplesd. Completely replacing
    # cropping with resizing could be discussed, but I believe it is not
    # beneficial. For the first version this is an ugly hack; for the second
    # version, a better version of the transforms should be written.

    if 'rand_crop_pos_neg_label' in config.transforms.keys():
        from monai.transforms import RandCropByPosNegLabeld
        args = config.transforms.rand_crop_pos_neg_label
        tfms += [
            RandCropByPosNegLabeld(
                keys=config.data.image_cols + config.data.label_cols,
                label_key=config.data.label_cols[0],
                spatial_size=args['spatial_size'],
                pos=args['pos'],
                neg=args['neg'],
                num_samples=args['num_samples'],
                image_key=config.data.image_cols[0],
                image_threshold=0,
            )
        ]

    elif 'rand_spatial_crop_samples' in config.transforms.keys():
        from monai.transforms import RandSpatialCropSamplesd
        args = config.transforms.rand_spatial_crop_samples
        tfms += [
            RandSpatialCropSamplesd(
                keys=config.data.image_cols + config.data.label_cols,
                roi_size=args['roi_size'],
                random_size=False,
                num_samples=args['num_samples'],
            )
        ]

    else:
        raise ValueError('Either `rand_crop_pos_neg_label` or `rand_spatial_crop_samples` '
                         'needs to be specified')

    # ---------- intensity transforms ----------

    if 'gaussian_noise' in config.transforms.keys():
        from monai.transforms import RandGaussianNoised
        args = config.transforms.gaussian_noise
        tfms += [
            RandGaussianNoised(
                keys=config.data.image_cols,
                mean=args['mean'],
                std=args['std'],
                prob=config.transforms.prob
            )
        ]

    if 'shift_intensity' in config.transforms.keys():
        from monai.transforms import RandShiftIntensityd
        args = config.transforms.shift_intensity
        tfms += [
            RandShiftIntensityd(
                keys=config.data.image_cols,
                offsets=args['offsets'],
                prob=config.transforms.prob
            )
        ]

    if 'gaussian_sharpen' in config.transforms.keys():
        from monai.transforms import RandGaussianSharpend
        args = config.transforms.gaussian_sharpen
        tfms += [
            RandGaussianSharpend(
                keys=config.data.image_cols,
                sigma1_x=args['sigma1_x'],
                sigma1_y=args['sigma1_y'],
                sigma1_z=args['sigma1_z'],
                sigma2_x=args['sigma2_x'],
                sigma2_y=args['sigma2_y'],
                sigma2_z=args['sigma2_z'],
                alpha=args['alpha'],
                prob=config.transforms.prob
            )
        ]

    if 'adjust_contrast' in config.transforms.keys():
        from monai.transforms import RandAdjustContrastd
        args = config.transforms.adjust_contrast
        tfms += [
            RandAdjustContrastd(
                keys=config.data.image_cols,
                gamma=args['gamma'],
                prob=config.transforms.prob
            )
        ]

    # Concat multisequence data to single tensors on the channel dimension.
    # Rename images to `CommonKeys.IMAGE` and labels to `CommonKeys.LABEL`
    # for better compatibility with monai.engines.

    tfms += [
        ConcatItemsd(
            keys=config.data.image_cols,
            name=CommonKeys.IMAGE,
            dim=0
        )
    ]

    tfms += [
        ConcatItemsd(
            keys=config.data.label_cols,
            name=CommonKeys.LABEL,
            dim=0
        )
    ]

    return Compose(tfms)

## ---------- valid transforms ----------

def get_val_transforms(config: dict):
    tfms = get_base_transforms(config=config)
    tfms += [EnsureTyped(keys=config.data.image_cols + config.data.label_cols)]
    tfms += [
        ConcatItemsd(
            keys=config.data.image_cols,
            name=CommonKeys.IMAGE,
            dim=0
        )
    ]

    tfms += [
        ConcatItemsd(
            keys=config.data.label_cols,
            name=CommonKeys.LABEL,
            dim=0
        )
    ]

    return Compose(tfms)

## ---------- test transforms ----------
# same as the valid transforms

def get_test_transforms(config: dict):
    tfms = get_base_transforms(config=config)
    tfms += [EnsureTyped(keys=config.data.image_cols + config.data.label_cols)]
    tfms += [
        ConcatItemsd(
            keys=config.data.image_cols,
            name=CommonKeys.IMAGE,
            dim=0
        )
    ]

    tfms += [
        ConcatItemsd(
            keys=config.data.label_cols,
            name=CommonKeys.LABEL,
            dim=0
        )
    ]

    return Compose(tfms)


def get_val_post_transforms(config: dict):
    tfms = [EnsureTyped(keys=[CommonKeys.PRED, CommonKeys.LABEL]),
            AsDiscreted(
                keys=CommonKeys.PRED,
                argmax=True,
                to_onehot=config.model.out_channels,
                num_classes=config.model.out_channels
            ),
            AsDiscreted(
                keys=CommonKeys.LABEL,
                to_onehot=config.model.out_channels,
                num_classes=config.model.out_channels
            ),
            KeepLargestConnectedComponentd(
                keys=CommonKeys.PRED,
                applied_labels=list(range(1, config.model.out_channels))
            ),
            ]
    return Compose(tfms)
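
`config.transforms` thus acts as a switchboard: each recognized key enables one augmentation with its own arguments, while `prob`, `mode`, `spacing` and `orientation` are shared. A sketch of a matching block (key and argument names taken from the lookups above, values illustrative):

import munch

transforms_config = munch.munchify({
    'prob': 0.175,
    'spacing': (0.5, 0.5, 0.5),
    'orientation': 'RAS',
    # 'mode' is normally derived in utils.load_config:
    # ('bilinear',) * n_image_cols + ('nearest',) * n_label_cols
    'mode': ('bilinear', 'nearest'),
    'rand_affine': {'rotate_range': 0.1, 'shear_range': 0.1, 'translate_range': 10},
    'rand_spatial_crop_samples': {'roi_size': (96, 96, 96), 'num_samples': 4},
    'gaussian_noise': {'mean': 0.0, 'std': 0.1},
})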
prostate/utils.py ADDED
@@ -0,0 +1,56 @@
import os
import yaml
import torch
import monai
import munch

def load_config(fn: str = 'config.yaml'):
    "Load config from YAML and return it as a munchified dictionary object"
    with open(fn, 'r') as stream:
        config = yaml.safe_load(stream)
    config = munch.munchify(config)

    if not config.overwrite:
        i = 1
        while os.path.exists(config.run_id + f'_{i}'):
            i += 1
        config.run_id += f'_{i}'

    config.out_dir = os.path.join(config.run_id, config.out_dir)
    config.log_dir = os.path.join(config.run_id, config.log_dir)

    if not isinstance(config.data.image_cols, (tuple, list)):
        config.data.image_cols = [config.data.image_cols]
    if not isinstance(config.data.label_cols, (tuple, list)):
        config.data.label_cols = [config.data.label_cols]

    config.transforms.mode = ('bilinear', ) * len(config.data.image_cols) + \
                             ('nearest', ) * len(config.data.label_cols)
    return config


def num_workers():
    "Get the maximum number of supported workers (CPU count - 2) for multiprocessing"
    import resource
    import multiprocessing

    # first check the maximum number of open files allowed on the system
    soft_limit, hard_limit = resource.getrlimit(resource.RLIMIT_NOFILE)
    n_workers = multiprocessing.cpu_count() - 2
    # giving each worker at least 256 open files should allow them to run smoothly
    max_workers = soft_limit // 256
    if max_workers < n_workers:
        print(
            "Will not use all available workers as the number of allowed open files is too small "
            "to ensure smooth multiprocessing. Current limits are:\n"
            f"\t soft_limit: {soft_limit}\n"
            f"\t hard_limit: {hard_limit}\n"
            f"Try increasing the limits to at least {256 * n_workers}. "
            "See https://superuser.com/questions/1200539/cannot-increase-open-file-limit-past-4096-ubuntu "
            "for more details."
        )
        return max_workers

    return n_workers

USE_AMP = monai.utils.get_torch_version_tuple() >= (1, 6)
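
`load_config` both parses the YAML and derives the fields other modules rely on (list-wrapped column names, per-key interpolation modes, run-scoped output paths). A usage sketch, assuming a `config.yaml` defining the keys read above:

from prostate.utils import load_config, num_workers, USE_AMP

config = load_config('config.yaml')
print(config.transforms.mode)  # e.g. ('bilinear', 'nearest') for one image and one label column
print(num_workers(), USE_AMP)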
prostate/viewer.py ADDED
@@ -0,0 +1,352 @@
import torch
import ipywidgets
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display
from itertools import chain, islice
from ipywidgets import interactive, widgets

def _create_label(text: str) -> ipywidgets.widgets.Label:
    "Create label widget"

    label = widgets.Label(
        text,
        layout=widgets.Layout(
            width='100%',
            display='flex',
            justify_content="center"
        )
    )
    return label

def _create_slider(
        slider_min: int,
        slider_max: int,
        value: int,
        step: int = 1,
        description: str = '',
        continuous_update: bool = True,
        readout: bool = False,
        slider_type: str = 'IntSlider',
        **kwargs) -> ipywidgets.widgets:
    "Create slider widget"

    slider = getattr(widgets, slider_type)(
        min=slider_min,
        max=slider_max,
        step=step,
        value=value,
        description=description,
        continuous_update=continuous_update,
        readout=readout,
        layout=widgets.Layout(width='99%', min_width='200px'),
        style={'description_width': 'initial'},
        **kwargs
    )
    return slider

def _create_button(description: str) -> ipywidgets.widgets.Button:
    "Create button widget"
    button = widgets.Button(
        description=description,
        layout=widgets.Layout(
            width='95%',
            margin='5px 5px'
        )
    )
    return button

def _create_togglebutton(description: str,
                         value: int,
                         **kwargs) -> ipywidgets.widgets.Button:
    "Create toggle button widget"
    button = widgets.ToggleButton(
        description=description,
        value=value,
        layout=widgets.Layout(
            width='95%',
            margin='5px 5px'
        ), **kwargs
    )
    return button


class BasicViewer():
    """Base class for viewing TensorDicom3D objects.

    Args:
        x: main image object to view as a rank 3 tensor
        y: either a segmentation mask as a rank 3 tensor or a label as str
        prediction: a class prediction as str
        description: description of the whole image
        figsize: size of the image, passed as a plotting argument
        cmap: colormap for the image
    Returns:
        Instance of BasicViewer
    """

    def __init__(self, x: torch.Tensor, y=None, prediction: str = None, description: str = None,
                 figsize=(3, 3), cmap: str = 'bone'):
        assert x.ndim == 3, f"x.ndim needs to be equal to 3 but is {x.ndim}"
        if isinstance(y, torch.Tensor):
            assert x.shape == y.shape, f"Shapes of x {x.shape} and y {y.shape} do not match"
        self.x = x
        self.y = y
        self.prediction = prediction
        self.description = description
        self.figsize = figsize
        self.cmap = cmap
        self.with_mask = isinstance(y, torch.Tensor)
        self.slice_range = (1, len(x))  # len(x) == im.shape[0]

    def _plot_slice(self, im_slice, with_mask, px_range):
        "Plot slice of image"
        fig, ax = plt.subplots(1, 1, figsize=self.figsize)
        ax.imshow(self.x[im_slice - 1, :, :].clip(*px_range), cmap=self.cmap)
        if isinstance(self.y, torch.Tensor) and with_mask:
            ax.imshow(self.y[im_slice - 1, :, :], cmap='jet', alpha=0.25)
        plt.axis('off')
        ax.set_xticks([])
        ax.set_yticks([])
        plt.show()

    def _create_image_box(self, figsize):
        "Create widget items, order them in item_box and generate view box"
        items = []

        if self.description: plot_description = _create_label(self.description)

        if isinstance(self.y, str):
            label = f'{self.y} | {self.prediction}' if self.prediction else self.y
            if self.prediction:
                font_color = 'green' if self.y == self.prediction else 'red'
                y_label = _create_label(r'\(\color{' + font_color + '} {' + label + r'}\)')
            else:
                y_label = _create_label(label)
        else:
            y_label = _create_label(' ')

        slice_slider = _create_slider(
            slider_min=min(self.slice_range),
            slider_max=max(self.slice_range),
            value=max(self.slice_range) // 2,
            readout=True)

        toggle_mask_button = _create_togglebutton('Show Mask', True)

        range_slider = _create_slider(
            slider_min=self.x.min().numpy(),
            slider_max=self.x.max().numpy(),
            value=[self.x.min().numpy(), self.x.max().numpy()],
            slider_type='FloatRangeSlider' if torch.is_floating_point(self.x) else 'IntRangeSlider',
            step=0.01 if torch.is_floating_point(self.x) else 1,
            readout=True)

        image_output = widgets.interactive_output(
            f=self._plot_slice,
            controls={'im_slice': slice_slider,
                      'with_mask': toggle_mask_button,
                      'px_range': range_slider})

        image_output.layout.height = f'{self.figsize[0]/1.2}in'  # suppress flickering
        image_output.layout.width = f'{self.figsize[1]/1.2}in'  # suppress flickering

        if self.description: items.append(plot_description)
        items.append(y_label)
        items.append(range_slider)
        items.append(image_output)
        if isinstance(self.y, torch.Tensor):
            slice_slider = widgets.HBox([slice_slider, toggle_mask_button])
        items.append(slice_slider)

        image_box = widgets.VBox(
            items,
            layout=widgets.Layout(
                border='none',
                margin='10px 5px 0px 0px',
                padding='5px'))

        return image_box

    def _generate_views(self):
        image_box = self._create_image_box(self.figsize)
        self.box = widgets.HBox(children=[image_box])

    @property
    def image_box(self):
        return self._create_image_box(self.figsize)

    def show(self):
        self._generate_views()
        plt.style.use('default')
        display(self.box)


class DicomExplorer(BasicViewer):
    """DICOM viewer for basic image analysis inside iPython notebooks.
    Can display a single 3D volume together with a segmentation mask, a histogram
    of voxel/pixel values and some summary statistics.
    Allows simple windowing by clipping the pixel/voxel values to a region, which
    can be manually specified.
    """

    vbox_layout = widgets.Layout(
        margin='10px 5px 5px 5px',
        padding='5px',
        display='flex',
        flex_flow='column',
        align_items='center',
        min_width='250px')

    def _plot_hist(self, px_range):
        x = self.x.numpy().flatten()
        fig, ax = plt.subplots(figsize=self.figsize)
        N, bins, patches = plt.hist(x, 100, color='grey')
        lwr = int(px_range[0] * 100 / max(x))
        upr = int(np.ceil(px_range[1] * 100 / max(x)))

        for i in range(0, lwr):
            patches[i].set_facecolor('grey' if lwr > 0 else 'darkblue')
        for i in range(lwr, upr):
            patches[i].set_facecolor('darkblue')
        for i in range(upr, 100):
            patches[i].set_facecolor('grey' if upr < 100 else 'darkblue')

        plt.show()

    def _image_summary(self, px_range):
        x = self.x.clip(*px_range)

        diffs = x - x.mean()
        var = torch.mean(torch.pow(diffs, 2.0))
        std = torch.pow(var, 0.5)
        zscores = diffs / std
        skews = torch.mean(torch.pow(zscores, 3.0))
        kurt = torch.mean(torch.pow(zscores, 4.0)) - 3.0

        table = f'Statistics:\n' + \
                f'  Mean px value:    {x.mean()} \n' + \
                f'  Std of px values: {x.std()} \n' + \
                f'  Min px value:     {x.min()} \n' + \
                f'  Max px value:     {x.max()} \n' + \
                f'  Median px value:  {x.median()} \n' + \
                f'  Skewness:         {skews} \n' + \
                f'  Kurtosis:         {kurt} \n\n' + \
                f'Tensor properties \n' + \
                f'  Tensor shape: {tuple(x.shape)}\n' + \
                f'  Tensor dtype: {x.dtype}'
        print(table)

    def _generate_views(self):

        slice_slider = _create_slider(
            slider_min=min(self.slice_range),
            slider_max=max(self.slice_range),
            value=max(self.slice_range) // 2,
            readout=True)

        toggle_mask_button = _create_togglebutton('Show Mask', True)

        range_slider = _create_slider(
            slider_min=self.x.min().numpy(),
            slider_max=self.x.max().numpy(),
            value=[self.x.min().numpy(), self.x.max().numpy()],
            continuous_update=False,
            slider_type='FloatRangeSlider' if torch.is_floating_point(self.x) else 'IntRangeSlider',
            step=0.01 if torch.is_floating_point(self.x) else 1)

        image_output = widgets.interactive_output(
            f=self._plot_slice,
            controls={'im_slice': slice_slider,
                      'with_mask': toggle_mask_button,
                      'px_range': range_slider})

        image_output.layout.height = f'{self.figsize[0]/1.2}in'  # suppress flickering
        image_output.layout.width = f'{self.figsize[1]/1.2}in'  # suppress flickering

        if isinstance(self.y, torch.Tensor):
            slice_slider = widgets.HBox([slice_slider, toggle_mask_button])

        hist_output = widgets.interactive_output(
            f=self._plot_hist,
            controls={'px_range': range_slider})

        hist_output.layout.height = f'{self.figsize[0]/1.2}in'  # suppress flickering
        hist_output.layout.width = f'{self.figsize[1]/1.2}in'  # suppress flickering

        toggle_mask_button = _create_togglebutton('Show Mask', True)

        table_output = widgets.interactive_output(
            f=self._image_summary,
            controls={'px_range': range_slider})

        table_box = widgets.VBox([table_output], layout=self.vbox_layout)

        hist_box = widgets.VBox(
            [hist_output, range_slider],
            layout=self.vbox_layout)

        image_box = widgets.VBox(
            [image_output, slice_slider],
            layout=self.vbox_layout)

        self.box = widgets.HBox(
            [image_box, hist_box, table_box],
            layout=widgets.Layout(
                border='solid 1px lightgrey',
                margin='10px 5px 0px 0px',
                padding='5px',
                width=f'{self.figsize[1]*2 + 3}in'))


class ListViewer(object):
    """Display multiple images with their masks or labels/predictions.
    Arguments:
        x (tuple, list): Tensor objects to view
        y (tuple, list): Tensor objects (in case of a segmentation task) or class labels as string
        prediction (str): Class predictions
        cmap: colormap for display of `x`
        max_n: maximum number of items to display
    """

    def __init__(self, x: (list, tuple), y=None, prediction: str = None, description: str = None,
                 figsize=(4, 4), cmap: str = 'bone', max_n=9):
        self.slice_range = (1, len(x))
        x = x[0:max_n]
        if y: y = y[0:max_n]
        self.x = x
        self.y = y
        self.prediction = prediction
        self.description = description
        self.figsize = figsize
        self.cmap = cmap
        self.max_n = max_n

    def _generate_views(self):
        n_images = len(self.x)
        image_grid, image_list = [], []

        for i in range(0, n_images):
            image = self.x[i]
            mask = self.y[i] if isinstance(self.y, list) else None
            pred = self.prediction[i] if self.prediction else None

            image_list.append(
                BasicViewer(
                    x=image,
                    y=mask,
                    prediction=pred,
                    figsize=self.figsize,
                    cmap=self.cmap)
                .image_box)

            if (i + 1) % np.ceil(np.sqrt(n_images)) == 0 or i == n_images - 1:
                image_grid.append(widgets.HBox(image_list))
                image_list = []

        self.box = widgets.VBox(children=image_grid)

    def show(self):
        self._generate_views()
        plt.style.use('default')
        display(self.box)
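
In a notebook the viewers work on any rank-3 tensor; a small self-contained sketch with a random volume and a toy mask:

import torch
from prostate.viewer import BasicViewer, DicomExplorer

volume = torch.rand(32, 128, 128)  # (slices, H, W)
mask = (volume > 0.95).long()      # toy segmentation mask with the same shape
BasicViewer(volume, y=mask, description='toy volume').show()
DicomExplorer(volume).show()       # adds a histogram and summary statistics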