osbm commited on
Commit
bb96fc5
·
1 Parent(s): 8ba2459

Delete prostate

Browse files
prostate/data.py DELETED
@@ -1,229 +0,0 @@
1
- # create dataloaders form csv file
2
-
3
- ## ---------- imports ----------
4
- import os
5
- import torch
6
- import shutil
7
- import numpy as np
8
- import pandas as pd
9
- from typing import Union
10
- from monai.utils import first
11
- from functools import partial
12
- from collections import namedtuple
13
- from monai.data import DataLoader as MonaiDataLoader
14
-
15
- from . import transforms
16
- from .utils import num_workers
17
-
18
-
19
- def import_dataset(config: dict):
20
- if config.data.dataset_type == 'persistent':
21
- from monai.data import PersistentDataset
22
- if os.path.exists(config.data.cache_dir):
23
- shutil.rmtree(config.data.cache_dir) # rm previous cache DS
24
- os.makedirs(config.data.cache_dir, exist_ok = True)
25
- Dataset = partial(PersistentDataset, cache_dir = config.data.cache_dir)
26
- elif config.data.dataset_type == 'cache':
27
- from monai.data import CacheDataset
28
- raise NotImplementedError('CacheDataset not yet implemented')
29
- else:
30
- from monai.data import Dataset
31
- return Dataset
32
-
33
-
34
- class DataLoader(MonaiDataLoader):
35
- "overwrite monai DataLoader for enhanced viewing capabilities"
36
-
37
- def show_batch(self,
38
- image_key: str='image',
39
- label_key: str='label',
40
- image_transform=lambda x: x.squeeze().transpose(0,2).flip(-2),
41
- label_transform=lambda x: x.squeeze().transpose(0,2).flip(-2)):
42
- """Args:
43
- image_key: dict key name for image to view
44
- label_key: dict kex name for corresponding label. Can be a tensor or str
45
- image_transform: transform input before it is passed to the viewer to ensure
46
- ndim of the image is equal to 3 and image is oriented correctly
47
- label_transform: transform labels before passed to the viewer, to ensure
48
- segmentations masks have same shape and orientations as images. Should be
49
- identity function of labels are str.
50
- """
51
- from .viewer import ListViewer
52
-
53
- batch = first(self)
54
- image = torch.unbind(batch[image_key], 0)
55
- label = torch.unbind(batch[label_key], 0)
56
-
57
- ListViewer([image_transform(im) for im in image],
58
- [label_transform(im) for im in label]).show()
59
-
60
- # TODO
61
- ## Work with 3 dataloaders
62
-
63
- def segmentation_dataloaders(config: dict,
64
- train: bool = None,
65
- valid: bool = None,
66
- test: bool = None,
67
- ):
68
- """Create segmentation dataloaders
69
- Args:
70
- config: config file
71
- train: whether to return a train DataLoader
72
- valid: whether to return a valid DataLoader
73
- test: whether to return a test DateLoader
74
- Args from config:
75
- data_dir: base directory for the data
76
- csv_name: path to csv file containing filenames and paths
77
- image_cols: columns in csv containing path to images
78
- label_cols: columns in csv containing path to label files
79
- dataset_type: PersistentDataset, CacheDataset and Dataset are supported
80
- cache_dir: cache directory to be used by PersistentDataset
81
- batch_size: batch size for training. Valid and test are always 1
82
- debug: run with reduced number of images
83
- Returns:
84
- list of:
85
- train_loader: DataLoader (optional, if train==True)
86
- valid_loader: DataLoader (optional, if valid==True)
87
- test_loader: DataLoader (optional, if test==True)
88
- """
89
-
90
- ## parse needed rguments from config
91
- if train is None: train = config.data.train
92
- if valid is None: valid = config.data.valid
93
- if test is None: test = config.data.test
94
-
95
- data_dir = config.data.data_dir
96
- train_csv = config.data.train_csv
97
- valid_csv = config.data.valid_csv
98
- test_csv = config.data.test_csv
99
- image_cols = config.data.image_cols
100
- label_cols = config.data.label_cols
101
- dataset_type = config.data.dataset_type
102
- cache_dir = config.data.cache_dir
103
- batch_size = config.data.batch_size
104
- debug = config.debug
105
-
106
- ## ---------- data dicts ----------
107
-
108
- # first a global data dict, containing only the filepath from image_cols and label_cols is created. For this,
109
- # the dataframe is reduced to only the relevant columns. Then the rows are iterated, converting each row into an
110
- # individual dict, as expected by monai
111
-
112
- if not isinstance(image_cols, (tuple, list)): image_cols = [image_cols]
113
- if not isinstance(label_cols, (tuple, list)): label_cols = [label_cols]
114
-
115
- train_df = pd.read_csv(train_csv)
116
- valid_df = pd.read_csv(valid_csv)
117
- test_df = pd.read_csv(test_csv)
118
- if debug:
119
- train_df = train_df.sample(25)
120
- valid_df = valid_df.sample(5)
121
-
122
- train_df['split']='train'
123
- valid_df['split']='valid'
124
- test_df['split']='test'
125
- whole_df = []
126
- if train: whole_df += [train_df]
127
- if valid: whole_df += [valid_df]
128
- if test: whole_df += [test_df]
129
- df = pd.concat(whole_df)
130
- cols = image_cols + label_cols
131
- for col in cols:
132
- # create absolute file name from relative fn in df and data_dir
133
- df[col] = [os.path.join(data_dir, fn) for fn in df[col]]
134
- if not os.path.exists(list(df[col])[0]):
135
- raise FileNotFoundError(list(df[col])[0])
136
-
137
-
138
- data_dict = [dict(row[1]) for row in df[cols].iterrows()]
139
- # data_dict is not the correct name, list_of_data_dicts would be more accurate, but also longer.
140
- # The data_dict looks like this:
141
- # [
142
- # {'image_col_1': 'data_dir/path/to/image1',
143
- # 'image_col_2': 'data_dir/path/to/image2'
144
- # 'label_col_1': 'data_dir/path/to/label1},
145
- # {'image_col_1': 'data_dir/path/to/image1',
146
- # 'image_col_2': 'data_dir/path/to/image2'
147
- # 'label_col_1': 'data_dir/path/to/label1},
148
- # ...]
149
- # Filename should now be absolute or relative to working directory
150
-
151
- # now we create separate data dicts for train, valid and test data respectively
152
- assert train or test or valid, 'No dataset type is specified (train/valid or test)'
153
-
154
- if test:
155
- test_files = list(map(data_dict.__getitem__, *np.where(df.split == 'test')))
156
-
157
- if valid:
158
- val_files = list(map(data_dict.__getitem__, *np.where(df.split == 'valid')))
159
-
160
- if train:
161
- train_files = list(map(data_dict.__getitem__, *np.where(df.split == 'train')))
162
-
163
- # transforms are specified in transforms.py and are just loaded here
164
- if train: train_transforms = transforms.get_train_transforms(config)
165
- if valid: val_transforms = transforms.get_val_transforms(config)
166
- if test: test_transforms = transforms.get_test_transforms(config)
167
-
168
-
169
- ## ---------- construct dataloaders ----------
170
- Dataset=import_dataset(config)
171
- data_loaders = []
172
- if train:
173
- train_ds = Dataset(
174
- data=train_files,
175
- transform=train_transforms
176
- )
177
- train_loader = DataLoader(
178
- train_ds,
179
- batch_size=batch_size,
180
- num_workers=num_workers(),
181
- shuffle=True
182
- )
183
- data_loaders.append(train_loader)
184
-
185
- if valid:
186
- val_ds = Dataset(
187
- data=val_files,
188
- transform=val_transforms
189
- )
190
- val_loader = DataLoader(
191
- val_ds,
192
- batch_size=1,
193
- num_workers=num_workers(),
194
- shuffle=False
195
- )
196
- data_loaders.append(val_loader)
197
-
198
- if test:
199
- test_ds = Dataset(
200
- data=test_files,
201
- transform=test_transforms
202
- )
203
- test_loader = DataLoader(
204
- test_ds,
205
- batch_size=1,
206
- num_workers=num_workers(),
207
- shuffle=False
208
- )
209
- data_loaders.append(test_loader)
210
-
211
- # if only one dataloader is constructed, return only this dataloader else return a named tuple with dataloaders,
212
- # so it is clear which DataLoader is train/valid or test
213
-
214
- if len(data_loaders) == 1:
215
- return data_loaders[0]
216
- else:
217
- DataLoaders = namedtuple(
218
- 'DataLoaders',
219
- # create str with specification of loader type if train and test are true but
220
- # valid is false string will be 'train test'
221
- ' '.join(
222
- [
223
- 'train' if train else '',
224
- 'valid' if valid else '',
225
- 'test' if test else ''
226
- ]
227
- ).strip()
228
- )
229
- return DataLoaders(*data_loaders)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
prostate/loss.py DELETED
@@ -1,23 +0,0 @@
1
- import monai
2
- from .utils import load_config
3
-
4
-
5
- def get_loss(config: dict):
6
- """Create a loss function of `type` with specific keyword arguments from config.
7
- Example:
8
-
9
- config.loss
10
- >>> {'DiceCELoss': {'include_background': False, 'softmax': True, 'to_onehot_y': True}}
11
-
12
- get_loss(config)
13
- >>> DiceCELoss(
14
- >>> (dice): DiceLoss()
15
- >>> (cross_entropy): CrossEntropyLoss()
16
- >>> )
17
-
18
- """
19
- loss_type = list(config.loss.keys())[0]
20
- loss_config = config.loss[loss_type]
21
- loss_fun = getattr(monai.losses, loss_type)
22
- loss = loss_fun(**loss_config)
23
- return loss
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
prostate/model.py DELETED
@@ -1,16 +0,0 @@
1
- # create a standard UNet
2
-
3
- from monai.networks.nets import UNet
4
-
5
- def get_model(config: dict):
6
- return UNet(
7
- spatial_dims=config.ndim,
8
- in_channels=len(config.data.image_cols),
9
- out_channels=config.model.out_channels,
10
- channels=config.model.channels,
11
- strides=config.model.strides,
12
- num_res_units=config.model.num_res_units,
13
- act=config.model.act,
14
- norm=config.model.norm,
15
- dropout=config.model.dropout,
16
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
prostate/optimizer.py DELETED
@@ -1,37 +0,0 @@
1
- import torch
2
- import monai
3
- from .utils import load_config
4
-
5
-
6
- def get_optimizer(model: torch.nn.Module,
7
- config: dict):
8
- """Create an optimizer of `type` with specific keyword arguments from config.
9
- Example:
10
-
11
- config.optimizer
12
- >>> {'Novograd': {'lr': 0.001, 'weight_decay': 0.01}}
13
-
14
- get_optimizer(model, config)
15
- >>> Novograd (
16
- >>> Parameter Group 0
17
- >>> amsgrad: False
18
- >>> betas: (0.9, 0.999)
19
- >>> eps: 1e-08
20
- >>> grad_averaging: False
21
- >>> lr: 0.0001
22
- >>> weight_decay: 0.001
23
- >>> )
24
-
25
- """
26
- optimizer_type = list(config.optimizer.keys())[0]
27
- opt_config = config.optimizer[optimizer_type]
28
- if hasattr(torch.optim, optimizer_type):
29
- optimizer_fun = getattr(torch.optim, optimizer_type)
30
- elif hasattr(monai.optimizers, optimizer_type):
31
- optimizer_fun = getattr(monai.optimizers, optimizer_type)
32
- else:
33
- raise ValueError(f'Optimizer {optimizer_type} not found')
34
- optimizer = optimizer_fun(model.parameters(), **opt_config)
35
- return optimizer
36
-
37
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
prostate/report.py DELETED
@@ -1,184 +0,0 @@
1
- import os
2
- import io
3
- import cv2
4
- import tqdm
5
- import torch
6
- import imageio
7
- import numpy as np
8
- import pandas as pd
9
- import matplotlib.pyplot as plt
10
-
11
- class ReportGenerator():
12
- "Generate markdown document, summarizing the training"
13
-
14
- def __init__(self, run_id, out_dir=None, log_dir=None):
15
-
16
- self.run_id, self.out_dir, self.log_dir = run_id, out_dir, log_dir
17
-
18
- if log_dir:
19
- self.train_logs = pd.read_csv(os.path.join(log_dir, 'train_logs.csv'))
20
- self.metric_logs = pd.read_csv(os.path.join(log_dir, 'metric_logs.csv'))
21
- if out_dir:
22
- self.dice = pd.read_csv(os.path.join(out_dir, 'MeanDice_raw.csv'))
23
- self.hausdorf = pd.read_csv(os.path.join(out_dir, 'HausdorffDistance_raw.csv'))
24
- self.surface = pd.read_csv(os.path.join(out_dir, 'SurfaceDistance_raw.csv'))
25
-
26
- self.mean_metrics = pd.DataFrame(
27
- {"mean_dice" : [round(np.mean(self.dice[col]),3) for col in self.dice if col.startswith('class')],
28
- "mean_hausdorf" : [round(np.mean(self.hausdorf[col]),3) for col in self.hausdorf if col.startswith('class')],
29
- "mean_surface" : [round(np.mean(self.surface[col]),3) for col in self.surface if col.startswith('class')]
30
- }).transpose()
31
-
32
- def generate_report(self, loss_plot=True, metric_plot=True, boxplots=True, animation=True):
33
- fn = os.path.join(self.run_id, 'report', 'SegmentationReport.md')
34
- os.makedirs(os.path.join(self.run_id, 'report'), exist_ok=True)
35
- with open(fn, 'w+') as f:
36
- f.write('# Segmentation Report\n\n')
37
-
38
- if loss_plot:
39
- fig = self.plot_loss(self.train_logs, self.metric_logs)
40
- plt.savefig(os.path.join(self.run_id, 'report', 'loss_and_lr.png'), dpi = 150)
41
-
42
- with open(fn, 'a') as f:
43
- f.write('## Loss, LR-Schedule and Key Metric\n')
44
- f.write('![Loss, LR-Schedule and Key Metric](loss_and_lr.png)\n\n')
45
-
46
- if metric_plot:
47
- fig = plt.figure("metrics", (18, 6))
48
-
49
- ax = plt.subplot(1, 3, 1)
50
- plt.ylim([0,1])
51
- plt.title("Mean Dice")
52
- plt.xlabel("epoch")
53
- plt.plot(self.metric_logs.index, self.metric_logs.MeanDice)
54
-
55
- ax = plt.subplot(1, 3, 2)
56
- plt.title("Mean Hausdorff Distance")
57
- plt.xlabel("epoch")
58
- plt.plot(self.metric_logs.index, self.metric_logs.HausdorffDistance)
59
-
60
- ax = plt.subplot(1, 3, 3)
61
- plt.title("Mean Surface Distance")
62
- plt.xlabel("epoch")
63
- plt.plot(self.metric_logs.index, self.metric_logs.SurfaceDistance)
64
-
65
- plt.savefig(os.path.join(self.run_id, 'report', 'metrics.png'), dpi = 150)
66
- fig.clear()
67
- plt.close()
68
-
69
- with open(fn, 'a') as f:
70
- f.write('## Metrics\n')
71
- f.write('![metrics](metrics.png)\n\n')
72
-
73
- if boxplots:
74
- fig = plt.figure("boxplots", (18, 6))
75
-
76
- ax = plt.subplot(1, 3, 1)
77
- plt.title("Dice")
78
- plt.xlabel("class")
79
- plt.boxplot(self.dice[[col for col in self.dice if col.startswith('class')]])
80
-
81
- ax = plt.subplot(1, 3, 2)
82
- plt.title("Hausdorff Distance")
83
- plt.xlabel("class")
84
- plt.boxplot(self.hausdorf[[col for col in self.hausdorf if col.startswith('class')]])
85
-
86
- ax = plt.subplot(1, 3, 3)
87
- plt.title("Surface Distance")
88
- plt.xlabel("class")
89
- plt.boxplot(self.surface[[col for col in self.surface if col.startswith('class')]])
90
-
91
- plt.savefig(os.path.join(self.run_id, 'report', 'boxplots.png'),dpi = 150)
92
-
93
- fig.clear()
94
- plt.close()
95
-
96
- with open(fn, 'a') as f:
97
- f.write(f"## Individual metrics\n\n")
98
- f.write(f"{self.mean_metrics.to_markdown()}\n\n")
99
- f.write(f"![boxplot](boxplots.png)\n\n")
100
- if animation:
101
- self.generate_gif()
102
- with open(fn, 'a') as f:
103
- f.write('## Visualization of progress\n')
104
- f.write('![progress](progress.gif)\n\n')
105
-
106
- def plot_loss(self, train_logs, metric_logs):
107
- iteration = train_logs.iteration/sum(train_logs.epoch == 1)
108
- fig = plt.figure("loss and lr", (12, 6))
109
-
110
- y_max = max(metric_logs.eval_loss) + 0.5
111
- if y_max > 3: y_max = 3
112
-
113
- ax = plt.subplot(1, 2, 1)
114
- plt.ylim([0,y_max])
115
- plt.title("Epoch Average Loss")
116
- plt.xlabel("epoch")
117
- plt.plot(iteration, train_logs.loss)
118
- plt.plot(metric_logs.index, metric_logs.eval_loss)
119
-
120
- ax = plt.subplot(1, 2, 2)
121
- ax.set_yscale('log')
122
- plt.title("LR Schedule")
123
- plt.xlabel("epoch")
124
- plt.plot(iteration, train_logs.lr)
125
- return fig
126
-
127
- def get_arr_from_fig(self, fig, dpi=180):
128
- buf = io.BytesIO()
129
- fig.savefig(buf, format="png", dpi=dpi)
130
- buf.seek(0)
131
- img_arr = np.frombuffer(buf.getvalue(), dtype=np.uint8)
132
- buf.close()
133
- img = cv2.imdecode(img_arr, 1)
134
- #img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
135
- return img
136
-
137
- def get_slices(self, im, slices):
138
- ims = torch.unbind(im[:, :, slices], -1) # extract n slices
139
- ims = [i.transpose(0,1).flip(0) for i in ims] # rotate slices 90 degrees
140
- if len(slices) > 4 and len(slices) % 2 == 0:
141
- n = len(slices) // 2
142
- ims1 = torch.cat(ims[0:n], 1)
143
- ims2 = torch.cat(ims[n:], 1)
144
- return torch.cat([ims1, ims2], 0)
145
- else:
146
- return torch.cat(ims, 1) # create tile
147
-
148
- def plot_images(self, fns, slices, cmap='Greys_r', figsize=15, **kwargs):
149
- ims = [torch.load(os.path.join(self.out_dir, 'preds', fn)).cpu().argmax(0) for fn in fns]
150
- ims = [self.get_slices(im, slices) for im in ims]
151
- ims = torch.cat(ims, 0)
152
- plt.figure(figsize=(figsize,figsize))
153
- plt.imshow(ims, cmap=cmap, **kwargs)
154
- plt.axis('off')
155
-
156
- def load_segmentation_image(self, fn):
157
- im = torch.load(fn).cpu().unsqueeze(0)
158
- im = torch.nn.functional.interpolate(im, (224, 224, 112))
159
- im = im.argmax(1).squeeze()
160
- im = self.get_slices(im, slices = (40, 48, 56, 74, 82, 90))
161
- im = im/im.max() * 255
162
- return im
163
-
164
- def generate_gif(self):
165
- with imageio.get_writer(
166
- os.path.join(self.run_id,'report','progress.gif'),
167
- mode='I',
168
- fps = max(self.train_logs.epoch) // 10) as writer: # make gif 10 seconds
169
- for epoch in tqdm.tqdm(list(self.train_logs.epoch.unique())):
170
- seg_fn = os.path.join(self.out_dir, 'preds', f"pred_epoch_{epoch}.pt")
171
- if os.path.exists(seg_fn): im = self.load_segmentation_image(seg_fn)
172
-
173
- plt_train_logs = self.train_logs[self.train_logs.epoch <= epoch]
174
- loss_plt = self.plot_loss(plt_train_logs, self.metric_logs[:epoch])
175
- loss_fig = self.get_arr_from_fig(loss_plt)[:,:,0]
176
-
177
- new_shape = im.shape[1], int(loss_fig.shape[0] / loss_fig.shape[1] * im.shape[1])
178
- loss_fig = cv2.resize(loss_fig, (im.shape[1], im.shape[0]))
179
-
180
- images = torch.cat([im, torch.tensor(loss_fig)], 0).numpy().astype(np.uint8)
181
- writer.append_data(images)
182
-
183
- loss_plt.clear()
184
- plt.close()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
prostate/train.py DELETED
@@ -1,482 +0,0 @@
1
- import os
2
- import yaml
3
- import munch
4
- import torch
5
- import ignite
6
- import monai
7
- import shutil
8
- import pandas as pd
9
-
10
- from typing import Union, List, Callable
11
- from ignite.contrib.handlers.tqdm_logger import ProgressBar
12
- from monai.handlers import (
13
- CheckpointSaver,
14
- StatsHandler,
15
- TensorBoardStatsHandler,
16
- TensorBoardImageHandler,
17
- ValidationHandler,
18
- from_engine,
19
- MeanDice,
20
- EarlyStopHandler,
21
- MetricLogger,
22
- MetricsSaver
23
- )
24
-
25
- from .data import segmentation_dataloaders
26
- from .model import get_model
27
- from .optimizer import get_optimizer
28
- from .loss import get_loss
29
- from .transforms import get_val_post_transforms
30
- from .utils import USE_AMP
31
-
32
- def loss_logger(engine):
33
- "write loss and lr of each iteration/epoch to file"
34
- iteration=engine.state.iteration
35
- epoch=engine.state.epoch
36
- loss=[o['loss'] for o in engine.state.output]
37
- loss=sum(loss)/len(loss)
38
- lr=engine.optimizer.param_groups[0]['lr']
39
- log_file=os.path.join(engine.config.log_dir, 'train_logs.csv')
40
- if not os.path.exists(log_file):
41
- with open(log_file, 'w+') as f:
42
- f.write('iteration,epoch,loss,lr\n')
43
- with open(log_file, 'a') as f:
44
- f.write(f'{iteration},{epoch},{loss},{lr}\n')
45
-
46
- def metric_logger(engine):
47
- "write `metrics` after each epoch to file"
48
- if engine.state.epoch > 1: # only key metric is calcualted in 1st epoch, needs fix
49
- metric_names=[k for k in engine.state.metrics.keys()]
50
- metrics=[str(engine.state.metrics[mn]) for mn in metric_names]
51
- log_file=os.path.join(engine.config.log_dir, 'metric_logs.csv')
52
- if not os.path.exists(log_file):
53
- with open(log_file, 'w+') as f:
54
- f.write(','.join(metric_names) + '\n')
55
- with open(log_file, 'a') as f:
56
- f.write(','.join(metrics) + '\n')
57
-
58
- def pred_logger(engine):
59
- "save `pred` each time metric improves"
60
- epoch=engine.state.epoch
61
- root = os.path.join(engine.config.out_dir, 'preds')
62
- if not os.path.exists(root):
63
- os.makedirs(root)
64
- torch.save(
65
- engine.state.output[0]['label'],
66
- os.path.join(root, f'label.pt')
67
- )
68
- torch.save(
69
- engine.state.output[0]['image'],
70
- os.path.join(root, f'image.pt')
71
- )
72
-
73
- if epoch==engine.state.best_metric_epoch:
74
- torch.save(
75
- engine.state.output[0]['pred'],
76
- os.path.join(root, f'pred_epoch_{epoch}.pt')
77
- )
78
-
79
-
80
- def get_val_handlers(
81
- network: torch.nn.Module,
82
- config: dict
83
- ) -> list:
84
- """Create default handlers for model validation
85
- Args:
86
- network:
87
- nn.Module subclass, the model to train
88
-
89
- Returns:
90
- a list of default handlers for validation: [
91
- StatsHandler:
92
- ???
93
- TensorBoardStatsHandler:
94
- Save loss from validation to `config.log_dir`, allow logging with TensorBoard
95
- CheckpointSaver:
96
- Save best model to `config.model_dir`
97
- ]
98
- """
99
-
100
- val_handlers=[
101
- StatsHandler(
102
- tag_name="metric_logger",
103
- epoch_print_logger=metric_logger,
104
- output_transform=lambda x: None
105
- ),
106
- StatsHandler(
107
- tag_name="pred_logger",
108
- epoch_print_logger=pred_logger,
109
- output_transform=lambda x: None
110
- ),
111
- TensorBoardStatsHandler(
112
- log_dir=config.log_dir,
113
- # tag_name="val_mean_dice",
114
- output_transform=lambda x: None
115
- ),
116
- TensorBoardImageHandler(
117
- log_dir=config.log_dir,
118
- batch_transform=from_engine(["image", "label"]),
119
- output_transform=from_engine(["pred"]),
120
- ),
121
- CheckpointSaver(
122
- save_dir=config.model_dir,
123
- save_dict={f"network_{config.run_id}": network},
124
- save_key_metric=True
125
- ),
126
-
127
- ]
128
-
129
- return val_handlers
130
-
131
-
132
- def get_train_handlers(
133
- evaluator: monai.engines.SupervisedEvaluator,
134
- config: dict
135
- ) -> list:
136
- """Create default handlers for model training
137
- Args:
138
- evaluator: an engine of type `monai.engines.SupervisedEvaluator` for evaluations
139
- every epoch
140
-
141
- Returns:
142
- list of default handlers for training: [
143
- ValidationHandler:
144
- Allows model validation every epoch
145
- StatsHandler:
146
- ???
147
- TensorBoardStatsHandler:
148
- Save loss from validation to `config.log_dir`, allow logging with TensorBoard
149
- ]
150
- """
151
-
152
- train_handlers=[
153
- ValidationHandler(
154
- validator=evaluator,
155
- interval=1,
156
- epoch_level=True
157
- ),
158
- StatsHandler(
159
- tag_name="train_loss",
160
- output_transform=from_engine(
161
- ["loss"],
162
- first=True
163
- )
164
- ),
165
- StatsHandler(
166
- tag_name='loss_logger',
167
- iteration_print_logger=loss_logger
168
- ),
169
- TensorBoardStatsHandler(
170
- log_dir=config.log_dir,
171
- tag_name="train_loss",
172
- output_transform=from_engine(
173
- ["loss"],
174
- first=True
175
- ),
176
- )
177
- ]
178
-
179
- return train_handlers
180
-
181
- def get_evaluator(
182
- config: dict,
183
- device: torch.device ,
184
- network: torch.nn.Module,
185
- val_data_loader: monai.data.dataloader.DataLoader,
186
- val_post_transforms: monai.transforms.compose.Compose,
187
- val_handlers: Union[Callable, List]=get_val_handlers
188
- ) -> monai.engines.SupervisedEvaluator:
189
-
190
- """Create default evaluator for training of a segmentation model
191
- Args:
192
- device:
193
- torch.cuda.device for model and engine
194
- network:
195
- nn.Module subclass, the model to train
196
- val_data_loader:
197
- Validation data loader, `monai.data.dataloader.DataLoader` subclass
198
- val_post_transforms:
199
- function to create transforms OR composed transforms
200
- val_handlers:
201
- function to create handerls OR List of handlers
202
-
203
- Returns:
204
- default evaluator for segmentation of type `monai.engines.SupervisedEvaluator`
205
- """
206
-
207
- if callable(val_handlers): val_handlers=val_handlers()
208
-
209
- evaluator=monai.engines.SupervisedEvaluator(
210
- device=device,
211
- val_data_loader=val_data_loader,
212
- network=network,
213
- inferer=monai.inferers.SlidingWindowInferer(
214
- roi_size=(96, 96, 96),
215
- sw_batch_size=4,
216
- overlap=0.5
217
- ),
218
- postprocessing=val_post_transforms,
219
- key_val_metric={
220
- "val_mean_dice": MeanDice(
221
- include_background=False,
222
- output_transform=from_engine(
223
- ["pred", "label"]
224
- )
225
- )
226
- },
227
- val_handlers=val_handlers,
228
- # if no FP16 support in GPU or PyTorch version < 1.6, will not enable AMP evaluation
229
- amp=USE_AMP,
230
- )
231
- evaluator.config=config
232
- return evaluator
233
-
234
-
235
- class SegmentationTrainer(monai.engines.SupervisedTrainer):
236
- "Default Trainer für supervised segmentation task"
237
- def __init__(self,
238
- config: dict,
239
- progress_bar: bool=True,
240
- early_stopping: bool=True,
241
- metrics: list=["MeanDice", "HausdorffDistance", "SurfaceDistance"],
242
- save_latest_metrics: bool=True
243
- ):
244
- self.config=config
245
- self._prepare_dirs()
246
- self.config.device=torch.device(self.config.device)
247
-
248
- train_loader, val_loader=segmentation_dataloaders(
249
- config=config,
250
- train=True,
251
- valid=True,
252
- test=False
253
- )
254
- network=get_model(config=config).to(config.device)
255
- optimizer=get_optimizer(
256
- network,
257
- config=config
258
- )
259
- loss_fn=get_loss(config=config)
260
- val_post_transforms=get_val_post_transforms(config=config)
261
- val_handlers=get_val_handlers(
262
- network,
263
- config=config
264
- )
265
-
266
- self.evaluator=get_evaluator(
267
- config=config,
268
- device=config.device,
269
- network=network,
270
- val_data_loader=val_loader,
271
- val_post_transforms=val_post_transforms,
272
- val_handlers=val_handlers,
273
-
274
- )
275
- train_handlers=get_train_handlers(
276
- self.evaluator,
277
- config=config
278
- )
279
-
280
- super().__init__(
281
- device=config.device,
282
- max_epochs=self.config.training.max_epochs,
283
- train_data_loader=train_loader,
284
- network=network,
285
- optimizer=optimizer,
286
- loss_function=loss_fn,
287
- inferer=monai.inferers.SimpleInferer(),
288
- train_handlers=train_handlers,
289
- amp=USE_AMP,
290
- )
291
-
292
- if early_stopping: self._add_early_stopping()
293
- if progress_bar: self._add_progress_bars()
294
-
295
- self.schedulers=[]
296
- # add different metrics dynamically
297
- for m in metrics:
298
- getattr(monai.handlers, m)(
299
- include_background=False,
300
- reduction="mean",
301
- output_transform=from_engine(
302
- ["pred", "label"]
303
- )
304
- ).attach(self.evaluator, m)
305
-
306
- self._add_metrics_logger()
307
- # add eval loss to metrics
308
- self._add_eval_loss()
309
-
310
- if save_latest_metrics: self._add_metrics_saver()
311
-
312
-
313
- def _prepare_dirs(self)->None:
314
- # create run_id, copy config file for reproducibility
315
- os.makedirs(self.config.run_id, exist_ok=True)
316
- with open(
317
- os.path.join(
318
- self.config.run_id,
319
- 'config.yaml'
320
- ), 'w+') as f:
321
- f.write(yaml.safe_dump(self.config))
322
-
323
- # delete old log_dir
324
- if os.path.exists(self.config.log_dir):
325
- shutil.rmtree(self.config.log_dir)
326
-
327
- def _add_early_stopping(self) -> None:
328
- early_stopping=EarlyStopHandler(
329
- patience=self.config.training.early_stopping_patience,
330
- min_delta=1e-4,
331
- score_function=lambda x: x.state.metrics[x.state.key_metric_name],
332
- trainer=self
333
- )
334
- self.evaluator.add_event_handler(
335
- ignite.engine.Events.COMPLETED,
336
- early_stopping
337
- )
338
-
339
- def _add_metrics_logger(self) -> None:
340
- self.metric_logger=MetricLogger(
341
- evaluator=self.evaluator
342
- )
343
- self.metric_logger.attach(self)
344
-
345
- def _add_progress_bars(self) -> None:
346
- trainer_pbar=ProgressBar()
347
- evaluator_pbar=ProgressBar(
348
- colour='green'
349
- )
350
- trainer_pbar.attach(
351
- self,
352
- output_transform=lambda output:{
353
- 'loss': torch.tensor(
354
- [x['loss'] for x in output]
355
- ).mean()
356
- }
357
- )
358
- evaluator_pbar.attach(self.evaluator)
359
-
360
- def _add_metrics_saver(self) -> None:
361
- metric_saver=MetricsSaver(
362
- save_dir=self.config.out_dir,
363
- metric_details='*',
364
- batch_transform=self._get_meta_dict,
365
- delimiter=','
366
- )
367
- metric_saver.attach(self.evaluator)
368
-
369
- def _add_eval_loss(self)->None:
370
- # TODO improve by adding this to val handlers
371
- eval_loss_handler=ignite.metrics.Loss(
372
- loss_fn=self.loss_function,
373
- output_transform=lambda output: (
374
- output[0]['pred'].unsqueeze(0), # add batch dim
375
- output[0]['label'].argmax(0, keepdim=True).unsqueeze(0) # reverse one-hot, add batch dim
376
- )
377
- )
378
- eval_loss_handler.attach(self.evaluator, 'eval_loss')
379
-
380
- def _get_meta_dict(self, batch) -> list:
381
- "Get dict of metadata from engine. Needed as `batch_transform`"
382
- image_cols=self.config.data.image_cols
383
- image_name=image_cols[0] if isinstance(image_cols, list) else image_cols
384
- key=f'{image_name}_meta_dict'
385
- return [item[key] for item in batch]
386
-
387
- def load_checkpoint(self, checkpoint=None):
388
- if not checkpoint:
389
- # get name of last checkpoint
390
- checkpoint = os.path.join(
391
- self.config.model_dir,
392
- f"network_{self.config.run_id}_key_metric={self.evaluator.state.best_metric:.4f}.pt"
393
- )
394
- self.network.load_state_dict(
395
- torch.load(checkpoint)
396
- )
397
-
398
- def run(self, try_resume_from_checkpoint=True) -> None:
399
- """Run training, if `try_resume_from_checkpoint` tries to
400
- load previous checkpoint stored at `self.config.model_dir`
401
- """
402
-
403
- if try_resume_from_checkpoint:
404
- checkpoints = [
405
- os.path.join(
406
- self.config.model_dir,
407
- checkpoint_name
408
- ) for checkpoint_name in os.listdir(
409
- self.config.model_dir
410
- ) if self.config.run_id in checkpoint_name
411
- ]
412
- try:
413
- checkpoint = sorted(checkpoints)[-1]
414
- self.load_checkpoint(checkpoint)
415
- print(f"resuming from previous checkpoint at {checkpoint}")
416
- except: pass # train from scratch
417
-
418
- # train the model
419
- super().run()
420
-
421
- # make metrics and losses more accessible
422
- self.loss={
423
- "iter": [_iter for _iter, _ in self.metric_logger.loss],
424
- "loss": [_loss for _, _loss in self.metric_logger.loss],
425
- "epoch": [_iter // self.state.epoch_length for _iter, _ in self.metric_logger.loss]
426
- }
427
-
428
- self.metrics={
429
- k: [item[1] for item in self.metric_logger.metrics[k]] for k in
430
- self.evaluator.state.metric_details.keys()
431
- }
432
- # pd.DataFrame(self.metrics).to_csv(f"{self.config.out_dir}/metric_logs.csv")
433
- # pd.DataFrame(self.loss).to_csv(f"{self.config.out_dir}/loss_logs.csv")
434
-
435
- def fit_one_cycle(self, try_resume_from_checkpoint=True) -> None:
436
- "Run training using one-cycle-policy"
437
- assert "FitOneCycle" not in self.schedulers, "FitOneCycle already added"
438
- fit_one_cycle=monai.handlers.LrScheduleHandler(
439
- torch.optim.lr_scheduler.OneCycleLR(
440
- optimizer=self.optimizer,
441
- max_lr=self.optimizer.param_groups[0]['lr'],
442
- steps_per_epoch=self.state.epoch_length,
443
- epochs=self.state.max_epochs
444
- ),
445
- epoch_level=False,
446
- name="FitOneCycle"
447
- )
448
- fit_one_cycle.attach(self)
449
- self.schedulers += ["FitOneCycle"]
450
-
451
- def reduce_lr_on_plateau(self,
452
- try_resume_from_checkpoint=True,
453
- factor=0.1,
454
- patience=10,
455
- min_lr=1e-10,
456
- verbose=True) -> None:
457
- "Reduce learning rate by `factor` every `patience` epochs if kex_metric does not improve"
458
- assert "ReduceLROnPlateau" not in self.schedulers, "ReduceLROnPlateau already added"
459
- reduce_lr_on_plateau=monai.handlers.LrScheduleHandler(
460
- torch.optim.lr_scheduler.ReduceLROnPlateau(
461
- optimizer=self.optimizer,
462
- factor=factor,
463
- patience=patience,
464
- min_lr=min_lr,
465
- verbose=verbose
466
- ),
467
- print_lr=True,
468
- name='ReduceLROnPlateau',
469
- epoch_level=True,
470
- step_transform=lambda engine: engine.state.metrics[engine.state.key_metric_name],
471
- )
472
- reduce_lr_on_plateau.attach(self.evaluator)
473
- self.schedulers += ["ReduceLROnPlateau"]
474
-
475
- def evaluate(self, checkpoint=None, dataloader=None):
476
- "Run evaluation with best saved checkpoint"
477
- self.load_checkpoint(checkpoint)
478
- if dataloader:
479
- self.evaluator.set_data(dataloader)
480
- self.evaluator.state.epoch_length=len(dataloader)
481
- self.evaluator.run()
482
- print(f"metrics saved to {self.config.out_dir}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
prostate/transforms.py DELETED
@@ -1,363 +0,0 @@
1
- # create transforms for training, validation and test dataset
2
-
3
- ## TODO: Make Transforms more dynamic by directly building from config args
4
- ## Maybe like this
5
- ## TFM_NAME=config.transforms.keys()[0]
6
- ## tfm_fun=getattr(monai.transforms, TFM_NAME)
7
- ## tmfs+=[tfms_fun(keys=image+cols, **config.transforms[TFM_NAME], prob=prob, mode=mode)
8
-
9
-
10
- ## ---------- imports ----------
11
- import os
12
- # only import of base transforms, others are imported as needed
13
- from monai.utils.enums import CommonKeys
14
- from monai.transforms import (
15
- Activationsd,
16
- AsDiscreted,
17
- Compose,
18
- ConcatItemsd,
19
- KeepLargestConnectedComponentd,
20
- LoadImaged,
21
- EnsureChannelFirstd,
22
- EnsureTyped,
23
- SaveImaged,
24
- ScaleIntensityd,
25
- NormalizeIntensityd
26
- )
27
- # images should be interploated with `bilinear` but masks with `nearest`
28
-
29
- ## ---------- base transforms ----------
30
- # applied everytime
31
- def get_base_transforms(
32
- config: dict,
33
- minv: int=0,
34
- maxv: int=1
35
- )->list:
36
-
37
- tfms=[]
38
- tfms+=[LoadImaged(keys=config.data.image_cols+config.data.label_cols)]
39
- tfms+=[EnsureChannelFirstd(keys=config.data.image_cols+config.data.label_cols)]
40
- if config.transforms.spacing:
41
- from monai.transforms import Spacingd
42
- tfms+=[
43
- Spacingd(
44
- keys=config.data.image_cols+config.data.label_cols,
45
- pixdim=config.transforms.spacing,
46
- mode=config.transforms.mode
47
- )
48
- ]
49
- if config.transforms.orientation:
50
- from monai.transforms import Orientationd
51
- tfms+=[
52
- Orientationd(
53
- keys=config.data.image_cols+config.data.label_cols,
54
- axcodes=config.transforms.orientation
55
- )
56
- ]
57
- tfms+=[
58
- ScaleIntensityd(
59
- keys=config.data.image_cols,
60
- minv=minv,
61
- maxv=maxv
62
- )
63
- ]
64
- tfms+=[NormalizeIntensityd(keys=config.data.image_cols)]
65
- return tfms
66
-
67
- ## ---------- train transforms ----------
68
-
69
- def get_train_transforms(config: dict):
70
- tfms=get_base_transforms(config=config)
71
-
72
- # ---------- specific transforms for mri ----------
73
- if 'rand_bias_field' in config.transforms.keys():
74
- from monai.transforms import RandBiasFieldd
75
- args=config.transforms.rand_bias_field
76
- tfms+=[
77
- RandBiasFieldd(
78
- keys=config.data.image_cols,
79
- degree=args['degree'],
80
- coeff_range=args['coeff_range'],
81
- prob=config.transforms.prob
82
- )
83
- ]
84
-
85
- if 'rand_gaussian_smooth' in config.transforms.keys():
86
- from monai.transforms import RandGaussianSmoothd
87
- args=config.transforms.rand_gaussian_smooth
88
- tfms+=[
89
- RandGaussianSmoothd(
90
- keys=config.data.image_cols,
91
- sigma_x=args['sigma_x'],
92
- sigma_y=args['sigma_y'],
93
- sigma_z=args['sigma_z'],
94
- prob=config.transforms.prob
95
- )
96
- ]
97
-
98
- if 'rand_gibbs_nose' in config.transforms.keys():
99
- from monai.transforms import RandGibbsNoised
100
- args=config.transforms.rand_gibbs_nose
101
- tfms+=[
102
- RandGibbsNoised(
103
- keys=config.data.image_cols,
104
- alpha=args['alpha'],
105
- prob=config.transforms.prob
106
- )
107
- ]
108
-
109
- # ---------- affine transforms ----------
110
-
111
- if 'rand_affine' in config.transforms.keys():
112
- from monai.transforms import RandAffined
113
- args=config.transforms.rand_affine
114
- tfms+=[
115
- RandAffined(
116
- keys=config.data.image_cols+config.data.label_cols,
117
- rotate_range=args['rotate_range'],
118
- shear_range=args['shear_range'],
119
- translate_range=args['translate_range'],
120
- mode=config.transforms.mode,
121
- prob=config.transforms.prob
122
- )
123
- ]
124
-
125
- if 'rand_rotate90' in config.transforms.keys():
126
- from monai.transforms import RandRotate90d
127
- args=config.transforms.rand_rotate90
128
- tfms+=[
129
- RandRotate90d(
130
- keys=config.data.image_cols+config.data.label_cols,
131
- spatial_axes=args['spatial_axes'],
132
- prob=config.transforms.prob
133
- )
134
- ]
135
-
136
- if 'rand_rotate' in config.transforms.keys():
137
- from monai.transforms import RandRotated
138
- args=config.transforms.rand_rotate
139
- tfms+=[
140
- RandRotated(
141
- keys=config.data.image_cols+config.data.label_cols,
142
- range_x=args['range_x'],
143
- range_y=args['range_y'],
144
- range_z=args['range_z'],
145
- mode=config.transforms.mode,
146
- prob=config.transforms.prob
147
- )
148
- ]
149
-
150
- if 'rand_elastic' in config.transforms.keys():
151
- if config['ndim'] == 3:
152
- from monai.transforms import Rand3DElasticd as RandElasticd
153
- elif config['ndim'] == 2:
154
- from monai.transforms import Rand2DElasticd as RandElasticd
155
- args=config.transforms.rand_elastic
156
- tfms+=[
157
- RandElasticd(
158
- keys=config.data.image_cols+config.data.label_cols,
159
- sigma_range=args['sigma_range'],
160
- magnitude_range=args['magnitude_range'],
161
- rotate_range=args['rotate_range'],
162
- shear_range=args['shear_range'],
163
- translate_range=args['translate_range'],
164
- mode=config.transforms.mode,
165
- prob=config.transforms.prob
166
- )
167
- ]
168
-
169
- if 'rand_zoom' in config.transforms.keys():
170
- from monai.transforms import RandZoomd
171
- args=config.transforms.rand_zoom
172
- tfms+=[
173
- RandZoomd(
174
- keys=config.data.image_cols+config.data.label_cols,
175
- min_zoom=args['min'],
176
- max_zoom=args['max'],
177
- mode=['area' if x == 'bilinear' else x for x in config.transforms.mode],
178
- prob=config.transforms.prob
179
- )
180
- ]
181
-
182
- # ---------- random cropping, very effective for large images ----------
183
- # RandCropByPosNegLabeld is not advisable for data with missing lables
184
- # e.g., segmentation of carcinomas which are not present on all images
185
- # thus fallback to RandSpatialCropSamplesd. Completly replacing Cropping
186
- # by just resizing could be discussed, but I believe it is not beneficial
187
- # For the first version, this is an ungly hack. For the second version,
188
- # a better verion for transforms should be written.
189
-
190
- if 'rand_crop_pos_neg_label' in config.transforms.keys():
191
- from monai.transforms import RandCropByPosNegLabeld
192
- args=config.transforms.rand_crop_pos_neg_label
193
- tfms+=[
194
- RandCropByPosNegLabeld(
195
- keys=config.data.image_cols+config.data.label_cols,
196
- label_key=config.data.label_cols[0],
197
- spatial_size=args['spatial_size'],
198
- pos=args['pos'],
199
- neg=args['neg'],
200
- num_samples=args['num_samples'],
201
- image_key=config.data.image_cols[0],
202
- image_threshold=0,
203
- )
204
- ]
205
-
206
- elif 'rand_spatial_crop_samples' in config.transforms.keys():
207
- from monai.transforms import RandSpatialCropSamplesd
208
- args=config.transforms.rand_spatial_crop_samples
209
- tfms+=[
210
- RandSpatialCropSamplesd(
211
- keys=config.data.image_cols+config.data.label_cols,
212
- roi_size=args['roi_size'],
213
- random_size=False,
214
- num_samples=args['num_samples'],
215
- )
216
- ]
217
-
218
- else:
219
- raise ValueError('Either `rand_crop_pos_neg_label` or `rand_spatial_crop_samples` '\
220
- 'need to be specified')
221
-
222
- # ---------- intensity transforms ----------
223
-
224
- if 'gaussian_noise' in config.transforms.keys():
225
- from monai.transforms import RandGaussianNoised
226
- args=config.transforms.gaussian_noise
227
- tfms+=[
228
- RandGaussianNoised(
229
- keys=config.data.image_cols,
230
- mean=args['mean'],
231
- std=args['std'],
232
- prob=config.transforms.prob
233
- )
234
- ]
235
-
236
- if 'shift_intensity' in config.transforms.keys():
237
- from monai.transforms import RandShiftIntensityd
238
- args=config.transforms.shift_intensity
239
- tfms+=[
240
- RandShiftIntensityd(
241
- keys=config.data.image_cols,
242
- offsets=args['offsets'],
243
- prob=config.transforms.prob
244
- )
245
- ]
246
-
247
- if 'gaussian_sharpen' in config.transforms.keys():
248
- from monai.transforms import RandGaussianSharpend
249
- args=config.transforms.gaussian_sharpen
250
- tfms+=[
251
- RandGaussianSharpend(
252
- keys=config.data.image_cols,
253
- sigma1_x=args['sigma1_x'],
254
- sigma1_y=args['sigma1_y'],
255
- sigma1_z=args['sigma1_z'],
256
- sigma2_x=args['sigma2_x'],
257
- sigma2_y=args['sigma2_y'],
258
- sigma2_z=args['sigma2_z'],
259
- alpha=args['alpha'],
260
- prob=config.transforms.prob
261
- )
262
- ]
263
-
264
- if 'adjust_contrast' in config.transforms.keys():
265
- from monai.transforms import RandAdjustContrastd
266
- args=config.transforms.adjust_contrast
267
- tfms+=[
268
- RandAdjustContrastd(
269
- keys=config.data.image_cols,
270
- gamma=args['gamma'],
271
- prob=config.transforms.prob
272
- )
273
- ]
274
-
275
- # Concat mutlisequence data to single Tensors on the ChannelDim
276
- # Rename images to `CommonKeys.IMAGE` and labels to `CommonKeys.LABELS`
277
- # for more compatibility with monai.engines
278
-
279
- tfms+=[
280
- ConcatItemsd(
281
- keys=config.data.image_cols,
282
- name=CommonKeys.IMAGE,
283
- dim=0
284
- )
285
- ]
286
-
287
- tfms+=[
288
- ConcatItemsd(
289
- keys=config.data.label_cols,
290
- name=CommonKeys.LABEL,
291
- dim=0
292
- )
293
- ]
294
-
295
- return Compose(tfms)
296
-
297
- ## ---------- valid transforms ----------
298
-
299
- def get_val_transforms(config: dict):
300
- tfms=get_base_transforms(config=config)
301
- tfms+=[EnsureTyped(keys=config.data.image_cols+config.data.label_cols)]
302
- tfms+=[
303
- ConcatItemsd(
304
- keys=config.data.image_cols,
305
- name=CommonKeys.IMAGE,
306
- dim=0
307
- )
308
- ]
309
-
310
- tfms+=[
311
- ConcatItemsd(
312
- keys=config.data.label_cols,
313
- name=CommonKeys.LABEL,
314
- dim=0
315
- )
316
- ]
317
-
318
- return Compose(tfms)
319
-
320
- ## ---------- test transforms ----------
321
- # same as valid transforms
322
-
323
- def get_test_transforms(config: dict):
324
- tfms=get_base_transforms(config=config)
325
- tfms+=[EnsureTyped(keys=config.data.image_cols+config.data.label_cols)]
326
- tfms+=[
327
- ConcatItemsd(
328
- keys=config.data.image_cols,
329
- name=CommonKeys.IMAGE,
330
- dim=0
331
- )
332
- ]
333
-
334
- tfms+=[
335
- ConcatItemsd(
336
- keys=config.data.label_cols,
337
- name=CommonKeys.LABEL,
338
- dim=0
339
- )
340
- ]
341
-
342
- return Compose(tfms)
343
-
344
-
345
- def get_val_post_transforms(config: dict):
346
- tfms=[EnsureTyped(keys=[CommonKeys.PRED, CommonKeys.LABEL]),
347
- AsDiscreted(
348
- keys=CommonKeys.PRED,
349
- argmax=True,
350
- to_onehot=config.model.out_channels,
351
- num_classes=config.model.out_channels
352
- ),
353
- AsDiscreted(
354
- keys=CommonKeys.LABEL,
355
- to_onehot=config.model.out_channels,
356
- num_classes=config.model.out_channels
357
- ),
358
- KeepLargestConnectedComponentd(
359
- keys=CommonKeys.PRED,
360
- applied_labels=list(range(1, config.model.out_channels))
361
- ),
362
- ]
363
- return Compose(tfms)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
prostate/utils.py DELETED
@@ -1,56 +0,0 @@
1
- import os
2
- import yaml
3
- import torch
4
- import monai
5
- import munch
6
-
7
- def load_config(fn: str='config.yaml'):
8
- "Load config from YAML and return a serialized dictionary object"
9
- with open(fn, 'r') as stream:
10
- config=yaml.safe_load(stream)
11
- config=munch.munchify(config)
12
-
13
- if not config.overwrite:
14
- i=1
15
- while os.path.exists(config.run_id+f'_{i}'):
16
- i+=1
17
- config.run_id+=f'_{i}'
18
-
19
- config.out_dir = os.path.join(config.run_id, config.out_dir)
20
- config.log_dir = os.path.join(config.run_id, config.log_dir)
21
-
22
- if not isinstance(config.data.image_cols, (tuple, list)):
23
- config.data.image_cols=[config.data.image_cols]
24
- if not isinstance(config.data.label_cols, (tuple, list)):
25
- config.data.label_cols=[config.data.label_cols]
26
-
27
- config.transforms.mode=('bilinear', ) * len(config.data.image_cols) + \
28
- ('nearest', ) * len(config.data.label_cols)
29
- return config
30
-
31
-
32
- def num_workers():
33
- "Get max supported workers -2 for multiprocessing"
34
- import resource
35
- import multiprocessing
36
-
37
- # first check for max number of open files allowed on system
38
- soft_limit, hard_limit = resource.getrlimit(resource.RLIMIT_NOFILE)
39
- n_workers=multiprocessing.cpu_count() - 2
40
- # giving each worker at least 256 open processes should allow them to run smoothly
41
- max_workers = soft_limit // 256
42
- if max_workers < n_workers:
43
- print(
44
- "Will not use all available workers as number of allowed open files is to small"
45
- "to ensure smooth multiprocessing. Current limits are:\n"
46
- f"\t soft_limit: {soft_limit}\n"
47
- f"\t hard_limit: {hard_limit}\n"
48
- "try increasing the limits to at least {256*n_workers}."
49
- "See https://superuser.com/questions/1200539/cannot-increase-open-file-limit-past-4096-ubuntu"
50
- "for more details"
51
- )
52
- return max_workers
53
-
54
- return n_workers
55
-
56
- USE_AMP=True if monai.utils.get_torch_version_tuple() >= (1, 6) else False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
prostate/viewer.py DELETED
@@ -1,352 +0,0 @@
1
- import torch
2
- import ipywidgets
3
- import numpy as np
4
- import matplotlib.pyplot as plt
5
- from IPython.display import display
6
- from itertools import chain, islice
7
- from ipywidgets import interactive, widgets
8
-
9
- def _create_label(text:str)->ipywidgets.widgets.Label:
10
- "Create label widget"
11
-
12
- label = widgets.Label(
13
- text,
14
- layout=widgets.Layout(
15
- width='100%',
16
- display='flex',
17
- justify_content="center"
18
- )
19
- )
20
- return label
21
-
22
- def _create_slider(
23
- slider_min: int,
24
- slider_max: int,
25
- value: int,
26
- step: int=1,
27
- description:str ='',
28
- continuous_update: bool=True,
29
- readout: bool=False,
30
- slider_type: str='IntSlider',
31
- **kwargs)->ipywidgets.widgets:
32
- "Create slider widget"
33
-
34
- slider = getattr(widgets, slider_type)(
35
- min=slider_min,
36
- max=slider_max,
37
- step=step,
38
- value=value,
39
- description=description,
40
- continuous_update=continuous_update,
41
- readout = readout,
42
- layout=widgets.Layout(width='99%', min_width='200px'),
43
- style={'description_width': 'initial'},
44
- **kwargs
45
- )
46
- return slider
47
-
48
- def _create_button(description:str)->ipywidgets.widgets.Button:
49
- "Create button widget"
50
- button = widgets.Button(
51
- description=description,
52
- layout=widgets.Layout(
53
- width='95%',
54
- margin='5px 5px'
55
- )
56
- )
57
- return button
58
-
59
- def _create_togglebutton(description: str,
60
- value: int,
61
- **kwargs)->ipywidgets.widgets.Button:
62
- "Create toggle button widget"
63
- button = widgets.ToggleButton(
64
- description=description,
65
- value = value,
66
- layout=widgets.Layout(
67
- width='95%',
68
- margin='5px 5px'
69
- ), **kwargs
70
- )
71
- return button
72
-
73
-
74
- class BasicViewer():
75
- """ Base class for viewing TensorDicom3D objects.
76
-
77
- Args:
78
- x: main image object to view as rank 3 tensor
79
- y: either a segmentation mask as as rank 3 tensor or a label as str.
80
- prediction: a class predicton as str
81
- description: description of the whole image
82
- figsize: size of image, passed as plotting argument
83
- cmap: colormap for the image
84
- Returns:
85
- Instance of BasicViewer
86
- """
87
-
88
- def __init__(self, x:torch.Tensor, y=None, prediction:str=None, description: str=None,
89
- figsize=(3, 3), cmap:str='bone'):
90
- assert x.ndim == 3, f"x.ndim needs to be equal to but is {x.ndim}"
91
- if isinstance(y, torch.Tensor):
92
- assert x.shape == y.shape, f"Shapes of x {x.shape} and y {y.shape} do not match"
93
- self.x=x
94
- self.y=y
95
- self.prediction=prediction
96
- self.description=description
97
- self.figsize=figsize
98
- self.cmap=cmap
99
- self.with_mask = isinstance(y, torch.Tensor)
100
- self.slice_range = (1, len(x)) # len(x) == im.shape[0]
101
-
102
- def _plot_slice(self, im_slice, with_mask, px_range):
103
- "Plot slice of image"
104
- fig, ax = plt.subplots(1, 1, figsize=self.figsize)
105
- ax.imshow(self.x[im_slice-1, :, :].clip(*px_range), cmap=self.cmap)
106
- if isinstance(self.y, (torch.Tensor)) and with_mask:
107
- ax.imshow(self.y[im_slice-1, :, :], cmap='jet', alpha = 0.25)
108
- plt.axis('off')
109
- ax.set_xticks([])
110
- ax.set_yticks([])
111
- plt.show()
112
-
113
- def _create_image_box(self, figsize):
114
- "Create widget items, order them in item_box and generate view box"
115
- items = []
116
-
117
- if self.description: plot_description = _create_label(self.description)
118
-
119
- if isinstance(self.y, str):
120
- label = f'{self.y} | {self.prediction}' if self.prediction else self.y
121
- if self.prediction:
122
- font_color = 'green' if self.y == self.prediction else 'red'
123
- y_label = _create_label(r'\(\color{' + font_color + '} {' + label + '}\)')
124
- else:
125
- y_label = _create_label(label)
126
- else: y_label = _create_label(' ')
127
-
128
- slice_slider = _create_slider(
129
- slider_min = min(self.slice_range),
130
- slider_max = max(self.slice_range),
131
- value = max(self.slice_range)//2,
132
- readout = True)
133
-
134
- toggle_mask_button = _create_togglebutton('Show Mask', True)
135
-
136
- range_slider = _create_slider(
137
- slider_min = self.x.min().numpy(),
138
- slider_max = self.x.max().numpy(),
139
- value = [self.x.min().numpy(), self.x.max().numpy()],
140
- slider_type = 'FloatRangeSlider' if torch.is_floating_point(self.x) else 'IntRandSlider',
141
- step = 0.01 if torch.is_floating_point(self.x) else 1,
142
- readout=True)
143
-
144
- image_output = widgets.interactive_output(
145
- f = self._plot_slice,
146
- controls = {'im_slice': slice_slider,
147
- 'with_mask': toggle_mask_button,
148
- 'px_range': range_slider})
149
-
150
- image_output.layout.height = f'{self.figsize[0]/1.2}in' # suppress flickering
151
- image_output.layout.width = f'{self.figsize[1]/1.2}in' # suppress flickering
152
-
153
- if self.description: items.append(plot_description)
154
- items.append(y_label)
155
- items.append(range_slider)
156
- items.append(image_output)
157
- if isinstance(self.y, torch.Tensor):
158
- slice_slider = widgets.HBox([slice_slider, toggle_mask_button])
159
- items.append(slice_slider)
160
-
161
- image_box=widgets.VBox(
162
- items,
163
- layout = widgets.Layout(
164
- border = 'none',
165
- margin = '10px 5px 0px 0px',
166
- padding = '5px'))
167
-
168
- return image_box
169
-
170
- def _generate_views(self):
171
- image_box = self._create_image_box(self.figsize)
172
- self.box = widgets.HBox(children=[image_box])
173
-
174
- @property
175
- def image_box(self):
176
- return self._create_image_box(self.figsize)
177
-
178
- def show(self):
179
- self._generate_views()
180
- plt.style.use('default')
181
- display(self.box)
182
-
183
-
184
- class DicomExplorer(BasicViewer):
185
- """ DICOM viewer for basic image analysis inside iPython notebooks.
186
- Can display a single 3D volume together with a segmentation mask, a histogram
187
- of voxel/pixel values and some summary statistics.
188
- Allows simple windowing by clipping the pixel/voxel values to a region, which
189
- can be manually specified.
190
-
191
- """
192
-
193
- vbox_layout = widgets.Layout(
194
- margin = '10px 5px 5px 5px',
195
- padding = '5px',
196
- display='flex',
197
- flex_flow='column',
198
- align_items='center',
199
- min_width = '250px')
200
-
201
- def _plot_hist(self, px_range):
202
- x = self.x.numpy().flatten()
203
- fig, ax = plt.subplots(figsize=self.figsize)
204
- N, bins, patches = plt.hist(x, 100, color='grey')
205
- lwr = int(px_range[0] * 100/max(x))
206
- upr = int(np.ceil(px_range[1] * 100/max(x)))
207
-
208
- for i in range(0,lwr):
209
- patches[i].set_facecolor('grey' if lwr > 0 else 'darkblue')
210
- for i in range(lwr, upr):
211
- patches[i].set_facecolor('darkblue')
212
- for i in range(upr,100):
213
- patches[i].set_facecolor('grey' if upr < 100 else 'darkblue')
214
-
215
- plt.show()
216
-
217
- def _image_summary(self, px_range):
218
- x = self.x.clip(*px_range)
219
-
220
- diffs = x - x.mean()
221
- var = torch.mean(torch.pow(diffs, 2.0))
222
- std = torch.pow(var, 0.5)
223
- zscores = diffs / std
224
- skews = torch.mean(torch.pow(zscores, 3.0))
225
- kurt = torch.mean(torch.pow(zscores, 4.0)) - 3.0
226
-
227
- table = f'Statistics:\n' + \
228
- f' Mean px value: {x.mean()} \n' + \
229
- f' Std of px values: {x.std()} \n' + \
230
- f' Min px value: {x.min()} \n' + \
231
- f' Max px value: {x.max()} \n' + \
232
- f' Median px value: {x.median()} \n' + \
233
- f' Skewness: {skews} \n' + \
234
- f' Kurtosis: {kurt} \n\n' + \
235
- f'Tensor properties \n' + \
236
- f' Tensor shape: {tuple(x.shape)}\n' + \
237
- f' Tensor dtype: {x.dtype}'
238
- print(table)
239
-
240
- def _generate_views(self):
241
-
242
- slice_slider = _create_slider(
243
- slider_min = min(self.slice_range),
244
- slider_max = max(self.slice_range),
245
- value = max(self.slice_range)//2,
246
- readout = True)
247
-
248
- toggle_mask_button = _create_togglebutton('Show Mask', True)
249
-
250
- range_slider = _create_slider(
251
- slider_min = self.x.min().numpy(),
252
- slider_max = self.x.max().numpy(),
253
- value = [self.x.min().numpy(), self.x.max().numpy()],
254
- continuous_update=False,
255
- slider_type = 'FloatRangeSlider' if torch.is_floating_point(self.x) else 'IntRandSlider',
256
- step = 0.01 if torch.is_floating_point(self.x) else 1)
257
-
258
- image_output = widgets.interactive_output(
259
- f = self._plot_slice,
260
- controls = {'im_slice': slice_slider,
261
- 'with_mask': toggle_mask_button,
262
- 'px_range': range_slider})
263
-
264
- image_output.layout.height = f'{self.figsize[0]/1.2}in' # suppress flickering
265
- image_output.layout.width = f'{self.figsize[1]/1.2}in' # suppress flickering
266
-
267
- if isinstance(self.y, torch.Tensor):
268
- slice_slider = widgets.HBox([slice_slider, toggle_mask_button])
269
-
270
- hist_output = widgets.interactive_output(
271
- f = self._plot_hist,
272
- controls = {'px_range': range_slider})
273
-
274
- hist_output.layout.height = f'{self.figsize[0]/1.2}in' # suppress flickering
275
- hist_output.layout.width = f'{self.figsize[1]/1.2}in' # suppress flickering
276
-
277
- toggle_mask_button = _create_togglebutton('Show Mask', True)
278
-
279
- table_output = widgets.interactive_output(
280
- f = self._image_summary,
281
- controls = {'px_range': range_slider})
282
-
283
- table_box = widgets.VBox([table_output], layout=self.vbox_layout)
284
-
285
- hist_box = widgets.VBox(
286
- [hist_output, range_slider],
287
- layout=self.vbox_layout)
288
-
289
- image_box = widgets.VBox(
290
- [image_output, slice_slider],
291
- layout=self.vbox_layout)
292
-
293
- self.box = widgets.HBox(
294
- [image_box, hist_box, table_box],
295
- layout = widgets.Layout(
296
- border = 'solid 1px lightgrey',
297
- margin = '10px 5px 0px 0px',
298
- padding = '5px',
299
- width = f'{self.figsize[1]*2 + 3}in'))
300
-
301
-
302
- class ListViewer(object):
303
- """ Display multipple images with their masks or labels/predictions.
304
- Arguments:
305
- x (tuple, list): Tensor objects to view
306
- y (tuple, list): Tensor objects (in case of segmentation task) or class labels as string.
307
- predictions (str): Class predictions
308
- cmap: colormap for display of `x`
309
- max_n: maximum number of items to display
310
- """
311
-
312
- def __init__(self, x:(list, tuple), y=None, prediction:str=None, description: str=None,
313
- figsize=(4, 4), cmap:str='bone', max_n = 9):
314
- self.slice_range = (1, len(x))
315
- x = x[0:max_n]
316
- if y: y = y[0:max_n]
317
- self.x=x
318
- self.y=y
319
- self.prediction=prediction
320
- self.description=description
321
- self.figsize=figsize
322
- self.cmap=cmap
323
- self.max_n=max_n
324
-
325
- def _generate_views(self):
326
- n_images = len(self.x)
327
- image_grid, image_list = [], []
328
-
329
- for i in range(0, n_images):
330
- image = self.x[i]
331
- mask = self.y[i] if isinstance(self.y, list) else None
332
- pred = self.prediction[i] if self.prediction else None
333
-
334
- image_list.append(
335
- BasicViewer(
336
- x = image,
337
- y = mask,
338
- prediction = pred,
339
- figsize = self.figsize,
340
- cmap = self.cmap)
341
- .image_box)
342
-
343
- if (i+1) % np.ceil(np.sqrt(n_images)) == 0 or i == n_images - 1:
344
- image_grid.append(widgets.HBox(image_list))
345
- image_list = []
346
-
347
- self.box = widgets.VBox(children=image_grid)
348
-
349
- def show(self):
350
- self._generate_views()
351
- plt.style.use('default')
352
- display(self.box)