Spaces:
Runtime error
Runtime error
Upload 9 files
Browse files- prostate/data.py +229 -0
- prostate/loss.py +23 -0
- prostate/model.py +16 -0
- prostate/optimizer.py +37 -0
- prostate/report.py +184 -0
- prostate/train.py +482 -0
- prostate/transforms.py +363 -0
- prostate/utils.py +56 -0
- prostate/viewer.py +352 -0
prostate/data.py
ADDED
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# create dataloaders form csv file
|
2 |
+
|
3 |
+
## ---------- imports ----------
|
4 |
+
import os
|
5 |
+
import torch
|
6 |
+
import shutil
|
7 |
+
import numpy as np
|
8 |
+
import pandas as pd
|
9 |
+
from typing import Union
|
10 |
+
from monai.utils import first
|
11 |
+
from functools import partial
|
12 |
+
from collections import namedtuple
|
13 |
+
from monai.data import DataLoader as MonaiDataLoader
|
14 |
+
|
15 |
+
from . import transforms
|
16 |
+
from .utils import num_workers
|
17 |
+
|
18 |
+
|
19 |
+
def import_dataset(config: dict):
|
20 |
+
if config.data.dataset_type == 'persistent':
|
21 |
+
from monai.data import PersistentDataset
|
22 |
+
if os.path.exists(config.data.cache_dir):
|
23 |
+
shutil.rmtree(config.data.cache_dir) # rm previous cache DS
|
24 |
+
os.makedirs(config.data.cache_dir, exist_ok = True)
|
25 |
+
Dataset = partial(PersistentDataset, cache_dir = config.data.cache_dir)
|
26 |
+
elif config.data.dataset_type == 'cache':
|
27 |
+
from monai.data import CacheDataset
|
28 |
+
raise NotImplementedError('CacheDataset not yet implemented')
|
29 |
+
else:
|
30 |
+
from monai.data import Dataset
|
31 |
+
return Dataset
|
32 |
+
|
33 |
+
|
34 |
+
class DataLoader(MonaiDataLoader):
|
35 |
+
"overwrite monai DataLoader for enhanced viewing capabilities"
|
36 |
+
|
37 |
+
def show_batch(self,
|
38 |
+
image_key: str='image',
|
39 |
+
label_key: str='label',
|
40 |
+
image_transform=lambda x: x.squeeze().transpose(0,2).flip(-2),
|
41 |
+
label_transform=lambda x: x.squeeze().transpose(0,2).flip(-2)):
|
42 |
+
"""Args:
|
43 |
+
image_key: dict key name for image to view
|
44 |
+
label_key: dict kex name for corresponding label. Can be a tensor or str
|
45 |
+
image_transform: transform input before it is passed to the viewer to ensure
|
46 |
+
ndim of the image is equal to 3 and image is oriented correctly
|
47 |
+
label_transform: transform labels before passed to the viewer, to ensure
|
48 |
+
segmentations masks have same shape and orientations as images. Should be
|
49 |
+
identity function of labels are str.
|
50 |
+
"""
|
51 |
+
from .viewer import ListViewer
|
52 |
+
|
53 |
+
batch = first(self)
|
54 |
+
image = torch.unbind(batch[image_key], 0)
|
55 |
+
label = torch.unbind(batch[label_key], 0)
|
56 |
+
|
57 |
+
ListViewer([image_transform(im) for im in image],
|
58 |
+
[label_transform(im) for im in label]).show()
|
59 |
+
|
60 |
+
# TODO
|
61 |
+
## Work with 3 dataloaders
|
62 |
+
|
63 |
+
def segmentation_dataloaders(config: dict,
|
64 |
+
train: bool = None,
|
65 |
+
valid: bool = None,
|
66 |
+
test: bool = None,
|
67 |
+
):
|
68 |
+
"""Create segmentation dataloaders
|
69 |
+
Args:
|
70 |
+
config: config file
|
71 |
+
train: whether to return a train DataLoader
|
72 |
+
valid: whether to return a valid DataLoader
|
73 |
+
test: whether to return a test DateLoader
|
74 |
+
Args from config:
|
75 |
+
data_dir: base directory for the data
|
76 |
+
csv_name: path to csv file containing filenames and paths
|
77 |
+
image_cols: columns in csv containing path to images
|
78 |
+
label_cols: columns in csv containing path to label files
|
79 |
+
dataset_type: PersistentDataset, CacheDataset and Dataset are supported
|
80 |
+
cache_dir: cache directory to be used by PersistentDataset
|
81 |
+
batch_size: batch size for training. Valid and test are always 1
|
82 |
+
debug: run with reduced number of images
|
83 |
+
Returns:
|
84 |
+
list of:
|
85 |
+
train_loader: DataLoader (optional, if train==True)
|
86 |
+
valid_loader: DataLoader (optional, if valid==True)
|
87 |
+
test_loader: DataLoader (optional, if test==True)
|
88 |
+
"""
|
89 |
+
|
90 |
+
## parse needed rguments from config
|
91 |
+
if train is None: train = config.data.train
|
92 |
+
if valid is None: valid = config.data.valid
|
93 |
+
if test is None: test = config.data.test
|
94 |
+
|
95 |
+
data_dir = config.data.data_dir
|
96 |
+
train_csv = config.data.train_csv
|
97 |
+
valid_csv = config.data.valid_csv
|
98 |
+
test_csv = config.data.test_csv
|
99 |
+
image_cols = config.data.image_cols
|
100 |
+
label_cols = config.data.label_cols
|
101 |
+
dataset_type = config.data.dataset_type
|
102 |
+
cache_dir = config.data.cache_dir
|
103 |
+
batch_size = config.data.batch_size
|
104 |
+
debug = config.debug
|
105 |
+
|
106 |
+
## ---------- data dicts ----------
|
107 |
+
|
108 |
+
# first a global data dict, containing only the filepath from image_cols and label_cols is created. For this,
|
109 |
+
# the dataframe is reduced to only the relevant columns. Then the rows are iterated, converting each row into an
|
110 |
+
# individual dict, as expected by monai
|
111 |
+
|
112 |
+
if not isinstance(image_cols, (tuple, list)): image_cols = [image_cols]
|
113 |
+
if not isinstance(label_cols, (tuple, list)): label_cols = [label_cols]
|
114 |
+
|
115 |
+
train_df = pd.read_csv(train_csv)
|
116 |
+
valid_df = pd.read_csv(valid_csv)
|
117 |
+
test_df = pd.read_csv(test_csv)
|
118 |
+
if debug:
|
119 |
+
train_df = train_df.sample(25)
|
120 |
+
valid_df = valid_df.sample(5)
|
121 |
+
|
122 |
+
train_df['split']='train'
|
123 |
+
valid_df['split']='valid'
|
124 |
+
test_df['split']='test'
|
125 |
+
whole_df = []
|
126 |
+
if train: whole_df += [train_df]
|
127 |
+
if valid: whole_df += [valid_df]
|
128 |
+
if test: whole_df += [test_df]
|
129 |
+
df = pd.concat(whole_df)
|
130 |
+
cols = image_cols + label_cols
|
131 |
+
for col in cols:
|
132 |
+
# create absolute file name from relative fn in df and data_dir
|
133 |
+
df[col] = [os.path.join(data_dir, fn) for fn in df[col]]
|
134 |
+
if not os.path.exists(list(df[col])[0]):
|
135 |
+
raise FileNotFoundError(list(df[col])[0])
|
136 |
+
|
137 |
+
|
138 |
+
data_dict = [dict(row[1]) for row in df[cols].iterrows()]
|
139 |
+
# data_dict is not the correct name, list_of_data_dicts would be more accurate, but also longer.
|
140 |
+
# The data_dict looks like this:
|
141 |
+
# [
|
142 |
+
# {'image_col_1': 'data_dir/path/to/image1',
|
143 |
+
# 'image_col_2': 'data_dir/path/to/image2'
|
144 |
+
# 'label_col_1': 'data_dir/path/to/label1},
|
145 |
+
# {'image_col_1': 'data_dir/path/to/image1',
|
146 |
+
# 'image_col_2': 'data_dir/path/to/image2'
|
147 |
+
# 'label_col_1': 'data_dir/path/to/label1},
|
148 |
+
# ...]
|
149 |
+
# Filename should now be absolute or relative to working directory
|
150 |
+
|
151 |
+
# now we create separate data dicts for train, valid and test data respectively
|
152 |
+
assert train or test or valid, 'No dataset type is specified (train/valid or test)'
|
153 |
+
|
154 |
+
if test:
|
155 |
+
test_files = list(map(data_dict.__getitem__, *np.where(df.split == 'test')))
|
156 |
+
|
157 |
+
if valid:
|
158 |
+
val_files = list(map(data_dict.__getitem__, *np.where(df.split == 'valid')))
|
159 |
+
|
160 |
+
if train:
|
161 |
+
train_files = list(map(data_dict.__getitem__, *np.where(df.split == 'train')))
|
162 |
+
|
163 |
+
# transforms are specified in transforms.py and are just loaded here
|
164 |
+
if train: train_transforms = transforms.get_train_transforms(config)
|
165 |
+
if valid: val_transforms = transforms.get_val_transforms(config)
|
166 |
+
if test: test_transforms = transforms.get_test_transforms(config)
|
167 |
+
|
168 |
+
|
169 |
+
## ---------- construct dataloaders ----------
|
170 |
+
Dataset=import_dataset(config)
|
171 |
+
data_loaders = []
|
172 |
+
if train:
|
173 |
+
train_ds = Dataset(
|
174 |
+
data=train_files,
|
175 |
+
transform=train_transforms
|
176 |
+
)
|
177 |
+
train_loader = DataLoader(
|
178 |
+
train_ds,
|
179 |
+
batch_size=batch_size,
|
180 |
+
num_workers=num_workers(),
|
181 |
+
shuffle=True
|
182 |
+
)
|
183 |
+
data_loaders.append(train_loader)
|
184 |
+
|
185 |
+
if valid:
|
186 |
+
val_ds = Dataset(
|
187 |
+
data=val_files,
|
188 |
+
transform=val_transforms
|
189 |
+
)
|
190 |
+
val_loader = DataLoader(
|
191 |
+
val_ds,
|
192 |
+
batch_size=1,
|
193 |
+
num_workers=num_workers(),
|
194 |
+
shuffle=False
|
195 |
+
)
|
196 |
+
data_loaders.append(val_loader)
|
197 |
+
|
198 |
+
if test:
|
199 |
+
test_ds = Dataset(
|
200 |
+
data=test_files,
|
201 |
+
transform=test_transforms
|
202 |
+
)
|
203 |
+
test_loader = DataLoader(
|
204 |
+
test_ds,
|
205 |
+
batch_size=1,
|
206 |
+
num_workers=num_workers(),
|
207 |
+
shuffle=False
|
208 |
+
)
|
209 |
+
data_loaders.append(test_loader)
|
210 |
+
|
211 |
+
# if only one dataloader is constructed, return only this dataloader else return a named tuple with dataloaders,
|
212 |
+
# so it is clear which DataLoader is train/valid or test
|
213 |
+
|
214 |
+
if len(data_loaders) == 1:
|
215 |
+
return data_loaders[0]
|
216 |
+
else:
|
217 |
+
DataLoaders = namedtuple(
|
218 |
+
'DataLoaders',
|
219 |
+
# create str with specification of loader type if train and test are true but
|
220 |
+
# valid is false string will be 'train test'
|
221 |
+
' '.join(
|
222 |
+
[
|
223 |
+
'train' if train else '',
|
224 |
+
'valid' if valid else '',
|
225 |
+
'test' if test else ''
|
226 |
+
]
|
227 |
+
).strip()
|
228 |
+
)
|
229 |
+
return DataLoaders(*data_loaders)
|
prostate/loss.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import monai
|
2 |
+
from .utils import load_config
|
3 |
+
|
4 |
+
|
5 |
+
def get_loss(config: dict):
|
6 |
+
"""Create a loss function of `type` with specific keyword arguments from config.
|
7 |
+
Example:
|
8 |
+
|
9 |
+
config.loss
|
10 |
+
>>> {'DiceCELoss': {'include_background': False, 'softmax': True, 'to_onehot_y': True}}
|
11 |
+
|
12 |
+
get_loss(config)
|
13 |
+
>>> DiceCELoss(
|
14 |
+
>>> (dice): DiceLoss()
|
15 |
+
>>> (cross_entropy): CrossEntropyLoss()
|
16 |
+
>>> )
|
17 |
+
|
18 |
+
"""
|
19 |
+
loss_type = list(config.loss.keys())[0]
|
20 |
+
loss_config = config.loss[loss_type]
|
21 |
+
loss_fun = getattr(monai.losses, loss_type)
|
22 |
+
loss = loss_fun(**loss_config)
|
23 |
+
return loss
|
prostate/model.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# create a standard UNet
|
2 |
+
|
3 |
+
from monai.networks.nets import UNet
|
4 |
+
|
5 |
+
def get_model(config: dict):
|
6 |
+
return UNet(
|
7 |
+
spatial_dims=config.ndim,
|
8 |
+
in_channels=len(config.data.image_cols),
|
9 |
+
out_channels=config.model.out_channels,
|
10 |
+
channels=config.model.channels,
|
11 |
+
strides=config.model.strides,
|
12 |
+
num_res_units=config.model.num_res_units,
|
13 |
+
act=config.model.act,
|
14 |
+
norm=config.model.norm,
|
15 |
+
dropout=config.model.dropout,
|
16 |
+
)
|
prostate/optimizer.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import monai
|
3 |
+
from .utils import load_config
|
4 |
+
|
5 |
+
|
6 |
+
def get_optimizer(model: torch.nn.Module,
|
7 |
+
config: dict):
|
8 |
+
"""Create an optimizer of `type` with specific keyword arguments from config.
|
9 |
+
Example:
|
10 |
+
|
11 |
+
config.optimizer
|
12 |
+
>>> {'Novograd': {'lr': 0.001, 'weight_decay': 0.01}}
|
13 |
+
|
14 |
+
get_optimizer(model, config)
|
15 |
+
>>> Novograd (
|
16 |
+
>>> Parameter Group 0
|
17 |
+
>>> amsgrad: False
|
18 |
+
>>> betas: (0.9, 0.999)
|
19 |
+
>>> eps: 1e-08
|
20 |
+
>>> grad_averaging: False
|
21 |
+
>>> lr: 0.0001
|
22 |
+
>>> weight_decay: 0.001
|
23 |
+
>>> )
|
24 |
+
|
25 |
+
"""
|
26 |
+
optimizer_type = list(config.optimizer.keys())[0]
|
27 |
+
opt_config = config.optimizer[optimizer_type]
|
28 |
+
if hasattr(torch.optim, optimizer_type):
|
29 |
+
optimizer_fun = getattr(torch.optim, optimizer_type)
|
30 |
+
elif hasattr(monai.optimizers, optimizer_type):
|
31 |
+
optimizer_fun = getattr(monai.optimizers, optimizer_type)
|
32 |
+
else:
|
33 |
+
raise ValueError(f'Optimizer {optimizer_type} not found')
|
34 |
+
optimizer = optimizer_fun(model.parameters(), **opt_config)
|
35 |
+
return optimizer
|
36 |
+
|
37 |
+
|
prostate/report.py
ADDED
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import io
|
3 |
+
import cv2
|
4 |
+
import tqdm
|
5 |
+
import torch
|
6 |
+
import imageio
|
7 |
+
import numpy as np
|
8 |
+
import pandas as pd
|
9 |
+
import matplotlib.pyplot as plt
|
10 |
+
|
11 |
+
class ReportGenerator():
|
12 |
+
"Generate markdown document, summarizing the training"
|
13 |
+
|
14 |
+
def __init__(self, run_id, out_dir=None, log_dir=None):
|
15 |
+
|
16 |
+
self.run_id, self.out_dir, self.log_dir = run_id, out_dir, log_dir
|
17 |
+
|
18 |
+
if log_dir:
|
19 |
+
self.train_logs = pd.read_csv(os.path.join(log_dir, 'train_logs.csv'))
|
20 |
+
self.metric_logs = pd.read_csv(os.path.join(log_dir, 'metric_logs.csv'))
|
21 |
+
if out_dir:
|
22 |
+
self.dice = pd.read_csv(os.path.join(out_dir, 'MeanDice_raw.csv'))
|
23 |
+
self.hausdorf = pd.read_csv(os.path.join(out_dir, 'HausdorffDistance_raw.csv'))
|
24 |
+
self.surface = pd.read_csv(os.path.join(out_dir, 'SurfaceDistance_raw.csv'))
|
25 |
+
|
26 |
+
self.mean_metrics = pd.DataFrame(
|
27 |
+
{"mean_dice" : [round(np.mean(self.dice[col]),3) for col in self.dice if col.startswith('class')],
|
28 |
+
"mean_hausdorf" : [round(np.mean(self.hausdorf[col]),3) for col in self.hausdorf if col.startswith('class')],
|
29 |
+
"mean_surface" : [round(np.mean(self.surface[col]),3) for col in self.surface if col.startswith('class')]
|
30 |
+
}).transpose()
|
31 |
+
|
32 |
+
def generate_report(self, loss_plot=True, metric_plot=True, boxplots=True, animation=True):
|
33 |
+
fn = os.path.join(self.run_id, 'report', 'SegmentationReport.md')
|
34 |
+
os.makedirs(os.path.join(self.run_id, 'report'), exist_ok=True)
|
35 |
+
with open(fn, 'w+') as f:
|
36 |
+
f.write('# Segmentation Report\n\n')
|
37 |
+
|
38 |
+
if loss_plot:
|
39 |
+
fig = self.plot_loss(self.train_logs, self.metric_logs)
|
40 |
+
plt.savefig(os.path.join(self.run_id, 'report', 'loss_and_lr.png'), dpi = 150)
|
41 |
+
|
42 |
+
with open(fn, 'a') as f:
|
43 |
+
f.write('## Loss, LR-Schedule and Key Metric\n')
|
44 |
+
f.write('\n\n')
|
45 |
+
|
46 |
+
if metric_plot:
|
47 |
+
fig = plt.figure("metrics", (18, 6))
|
48 |
+
|
49 |
+
ax = plt.subplot(1, 3, 1)
|
50 |
+
plt.ylim([0,1])
|
51 |
+
plt.title("Mean Dice")
|
52 |
+
plt.xlabel("epoch")
|
53 |
+
plt.plot(self.metric_logs.index, self.metric_logs.MeanDice)
|
54 |
+
|
55 |
+
ax = plt.subplot(1, 3, 2)
|
56 |
+
plt.title("Mean Hausdorff Distance")
|
57 |
+
plt.xlabel("epoch")
|
58 |
+
plt.plot(self.metric_logs.index, self.metric_logs.HausdorffDistance)
|
59 |
+
|
60 |
+
ax = plt.subplot(1, 3, 3)
|
61 |
+
plt.title("Mean Surface Distance")
|
62 |
+
plt.xlabel("epoch")
|
63 |
+
plt.plot(self.metric_logs.index, self.metric_logs.SurfaceDistance)
|
64 |
+
|
65 |
+
plt.savefig(os.path.join(self.run_id, 'report', 'metrics.png'), dpi = 150)
|
66 |
+
fig.clear()
|
67 |
+
plt.close()
|
68 |
+
|
69 |
+
with open(fn, 'a') as f:
|
70 |
+
f.write('## Metrics\n')
|
71 |
+
f.write('\n\n')
|
72 |
+
|
73 |
+
if boxplots:
|
74 |
+
fig = plt.figure("boxplots", (18, 6))
|
75 |
+
|
76 |
+
ax = plt.subplot(1, 3, 1)
|
77 |
+
plt.title("Dice")
|
78 |
+
plt.xlabel("class")
|
79 |
+
plt.boxplot(self.dice[[col for col in self.dice if col.startswith('class')]])
|
80 |
+
|
81 |
+
ax = plt.subplot(1, 3, 2)
|
82 |
+
plt.title("Hausdorff Distance")
|
83 |
+
plt.xlabel("class")
|
84 |
+
plt.boxplot(self.hausdorf[[col for col in self.hausdorf if col.startswith('class')]])
|
85 |
+
|
86 |
+
ax = plt.subplot(1, 3, 3)
|
87 |
+
plt.title("Surface Distance")
|
88 |
+
plt.xlabel("class")
|
89 |
+
plt.boxplot(self.surface[[col for col in self.surface if col.startswith('class')]])
|
90 |
+
|
91 |
+
plt.savefig(os.path.join(self.run_id, 'report', 'boxplots.png'),dpi = 150)
|
92 |
+
|
93 |
+
fig.clear()
|
94 |
+
plt.close()
|
95 |
+
|
96 |
+
with open(fn, 'a') as f:
|
97 |
+
f.write(f"## Individual metrics\n\n")
|
98 |
+
f.write(f"{self.mean_metrics.to_markdown()}\n\n")
|
99 |
+
f.write(f"\n\n")
|
100 |
+
if animation:
|
101 |
+
self.generate_gif()
|
102 |
+
with open(fn, 'a') as f:
|
103 |
+
f.write('## Visualization of progress\n')
|
104 |
+
f.write('\n\n')
|
105 |
+
|
106 |
+
def plot_loss(self, train_logs, metric_logs):
|
107 |
+
iteration = train_logs.iteration/sum(train_logs.epoch == 1)
|
108 |
+
fig = plt.figure("loss and lr", (12, 6))
|
109 |
+
|
110 |
+
y_max = max(metric_logs.eval_loss) + 0.5
|
111 |
+
if y_max > 3: y_max = 3
|
112 |
+
|
113 |
+
ax = plt.subplot(1, 2, 1)
|
114 |
+
plt.ylim([0,y_max])
|
115 |
+
plt.title("Epoch Average Loss")
|
116 |
+
plt.xlabel("epoch")
|
117 |
+
plt.plot(iteration, train_logs.loss)
|
118 |
+
plt.plot(metric_logs.index, metric_logs.eval_loss)
|
119 |
+
|
120 |
+
ax = plt.subplot(1, 2, 2)
|
121 |
+
ax.set_yscale('log')
|
122 |
+
plt.title("LR Schedule")
|
123 |
+
plt.xlabel("epoch")
|
124 |
+
plt.plot(iteration, train_logs.lr)
|
125 |
+
return fig
|
126 |
+
|
127 |
+
def get_arr_from_fig(self, fig, dpi=180):
|
128 |
+
buf = io.BytesIO()
|
129 |
+
fig.savefig(buf, format="png", dpi=dpi)
|
130 |
+
buf.seek(0)
|
131 |
+
img_arr = np.frombuffer(buf.getvalue(), dtype=np.uint8)
|
132 |
+
buf.close()
|
133 |
+
img = cv2.imdecode(img_arr, 1)
|
134 |
+
#img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
135 |
+
return img
|
136 |
+
|
137 |
+
def get_slices(self, im, slices):
|
138 |
+
ims = torch.unbind(im[:, :, slices], -1) # extract n slices
|
139 |
+
ims = [i.transpose(0,1).flip(0) for i in ims] # rotate slices 90 degrees
|
140 |
+
if len(slices) > 4 and len(slices) % 2 == 0:
|
141 |
+
n = len(slices) // 2
|
142 |
+
ims1 = torch.cat(ims[0:n], 1)
|
143 |
+
ims2 = torch.cat(ims[n:], 1)
|
144 |
+
return torch.cat([ims1, ims2], 0)
|
145 |
+
else:
|
146 |
+
return torch.cat(ims, 1) # create tile
|
147 |
+
|
148 |
+
def plot_images(self, fns, slices, cmap='Greys_r', figsize=15, **kwargs):
|
149 |
+
ims = [torch.load(os.path.join(self.out_dir, 'preds', fn)).cpu().argmax(0) for fn in fns]
|
150 |
+
ims = [self.get_slices(im, slices) for im in ims]
|
151 |
+
ims = torch.cat(ims, 0)
|
152 |
+
plt.figure(figsize=(figsize,figsize))
|
153 |
+
plt.imshow(ims, cmap=cmap, **kwargs)
|
154 |
+
plt.axis('off')
|
155 |
+
|
156 |
+
def load_segmentation_image(self, fn):
|
157 |
+
im = torch.load(fn).cpu().unsqueeze(0)
|
158 |
+
im = torch.nn.functional.interpolate(im, (224, 224, 112))
|
159 |
+
im = im.argmax(1).squeeze()
|
160 |
+
im = self.get_slices(im, slices = (40, 48, 56, 74, 82, 90))
|
161 |
+
im = im/im.max() * 255
|
162 |
+
return im
|
163 |
+
|
164 |
+
def generate_gif(self):
|
165 |
+
with imageio.get_writer(
|
166 |
+
os.path.join(self.run_id,'report','progress.gif'),
|
167 |
+
mode='I',
|
168 |
+
fps = max(self.train_logs.epoch) // 10) as writer: # make gif 10 seconds
|
169 |
+
for epoch in tqdm.tqdm(list(self.train_logs.epoch.unique())):
|
170 |
+
seg_fn = os.path.join(self.out_dir, 'preds', f"pred_epoch_{epoch}.pt")
|
171 |
+
if os.path.exists(seg_fn): im = self.load_segmentation_image(seg_fn)
|
172 |
+
|
173 |
+
plt_train_logs = self.train_logs[self.train_logs.epoch <= epoch]
|
174 |
+
loss_plt = self.plot_loss(plt_train_logs, self.metric_logs[:epoch])
|
175 |
+
loss_fig = self.get_arr_from_fig(loss_plt)[:,:,0]
|
176 |
+
|
177 |
+
new_shape = im.shape[1], int(loss_fig.shape[0] / loss_fig.shape[1] * im.shape[1])
|
178 |
+
loss_fig = cv2.resize(loss_fig, (im.shape[1], im.shape[0]))
|
179 |
+
|
180 |
+
images = torch.cat([im, torch.tensor(loss_fig)], 0).numpy().astype(np.uint8)
|
181 |
+
writer.append_data(images)
|
182 |
+
|
183 |
+
loss_plt.clear()
|
184 |
+
plt.close()
|
prostate/train.py
ADDED
@@ -0,0 +1,482 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import yaml
|
3 |
+
import munch
|
4 |
+
import torch
|
5 |
+
import ignite
|
6 |
+
import monai
|
7 |
+
import shutil
|
8 |
+
import pandas as pd
|
9 |
+
|
10 |
+
from typing import Union, List, Callable
|
11 |
+
from ignite.contrib.handlers.tqdm_logger import ProgressBar
|
12 |
+
from monai.handlers import (
|
13 |
+
CheckpointSaver,
|
14 |
+
StatsHandler,
|
15 |
+
TensorBoardStatsHandler,
|
16 |
+
TensorBoardImageHandler,
|
17 |
+
ValidationHandler,
|
18 |
+
from_engine,
|
19 |
+
MeanDice,
|
20 |
+
EarlyStopHandler,
|
21 |
+
MetricLogger,
|
22 |
+
MetricsSaver
|
23 |
+
)
|
24 |
+
|
25 |
+
from .data import segmentation_dataloaders
|
26 |
+
from .model import get_model
|
27 |
+
from .optimizer import get_optimizer
|
28 |
+
from .loss import get_loss
|
29 |
+
from .transforms import get_val_post_transforms
|
30 |
+
from .utils import USE_AMP
|
31 |
+
|
32 |
+
def loss_logger(engine):
|
33 |
+
"write loss and lr of each iteration/epoch to file"
|
34 |
+
iteration=engine.state.iteration
|
35 |
+
epoch=engine.state.epoch
|
36 |
+
loss=[o['loss'] for o in engine.state.output]
|
37 |
+
loss=sum(loss)/len(loss)
|
38 |
+
lr=engine.optimizer.param_groups[0]['lr']
|
39 |
+
log_file=os.path.join(engine.config.log_dir, 'train_logs.csv')
|
40 |
+
if not os.path.exists(log_file):
|
41 |
+
with open(log_file, 'w+') as f:
|
42 |
+
f.write('iteration,epoch,loss,lr\n')
|
43 |
+
with open(log_file, 'a') as f:
|
44 |
+
f.write(f'{iteration},{epoch},{loss},{lr}\n')
|
45 |
+
|
46 |
+
def metric_logger(engine):
|
47 |
+
"write `metrics` after each epoch to file"
|
48 |
+
if engine.state.epoch > 1: # only key metric is calcualted in 1st epoch, needs fix
|
49 |
+
metric_names=[k for k in engine.state.metrics.keys()]
|
50 |
+
metrics=[str(engine.state.metrics[mn]) for mn in metric_names]
|
51 |
+
log_file=os.path.join(engine.config.log_dir, 'metric_logs.csv')
|
52 |
+
if not os.path.exists(log_file):
|
53 |
+
with open(log_file, 'w+') as f:
|
54 |
+
f.write(','.join(metric_names) + '\n')
|
55 |
+
with open(log_file, 'a') as f:
|
56 |
+
f.write(','.join(metrics) + '\n')
|
57 |
+
|
58 |
+
def pred_logger(engine):
|
59 |
+
"save `pred` each time metric improves"
|
60 |
+
epoch=engine.state.epoch
|
61 |
+
root = os.path.join(engine.config.out_dir, 'preds')
|
62 |
+
if not os.path.exists(root):
|
63 |
+
os.makedirs(root)
|
64 |
+
torch.save(
|
65 |
+
engine.state.output[0]['label'],
|
66 |
+
os.path.join(root, f'label.pt')
|
67 |
+
)
|
68 |
+
torch.save(
|
69 |
+
engine.state.output[0]['image'],
|
70 |
+
os.path.join(root, f'image.pt')
|
71 |
+
)
|
72 |
+
|
73 |
+
if epoch==engine.state.best_metric_epoch:
|
74 |
+
torch.save(
|
75 |
+
engine.state.output[0]['pred'],
|
76 |
+
os.path.join(root, f'pred_epoch_{epoch}.pt')
|
77 |
+
)
|
78 |
+
|
79 |
+
|
80 |
+
def get_val_handlers(
|
81 |
+
network: torch.nn.Module,
|
82 |
+
config: dict
|
83 |
+
) -> list:
|
84 |
+
"""Create default handlers for model validation
|
85 |
+
Args:
|
86 |
+
network:
|
87 |
+
nn.Module subclass, the model to train
|
88 |
+
|
89 |
+
Returns:
|
90 |
+
a list of default handlers for validation: [
|
91 |
+
StatsHandler:
|
92 |
+
???
|
93 |
+
TensorBoardStatsHandler:
|
94 |
+
Save loss from validation to `config.log_dir`, allow logging with TensorBoard
|
95 |
+
CheckpointSaver:
|
96 |
+
Save best model to `config.model_dir`
|
97 |
+
]
|
98 |
+
"""
|
99 |
+
|
100 |
+
val_handlers=[
|
101 |
+
StatsHandler(
|
102 |
+
tag_name="metric_logger",
|
103 |
+
epoch_print_logger=metric_logger,
|
104 |
+
output_transform=lambda x: None
|
105 |
+
),
|
106 |
+
StatsHandler(
|
107 |
+
tag_name="pred_logger",
|
108 |
+
epoch_print_logger=pred_logger,
|
109 |
+
output_transform=lambda x: None
|
110 |
+
),
|
111 |
+
TensorBoardStatsHandler(
|
112 |
+
log_dir=config.log_dir,
|
113 |
+
# tag_name="val_mean_dice",
|
114 |
+
output_transform=lambda x: None
|
115 |
+
),
|
116 |
+
TensorBoardImageHandler(
|
117 |
+
log_dir=config.log_dir,
|
118 |
+
batch_transform=from_engine(["image", "label"]),
|
119 |
+
output_transform=from_engine(["pred"]),
|
120 |
+
),
|
121 |
+
CheckpointSaver(
|
122 |
+
save_dir=config.model_dir,
|
123 |
+
save_dict={f"network_{config.run_id}": network},
|
124 |
+
save_key_metric=True
|
125 |
+
),
|
126 |
+
|
127 |
+
]
|
128 |
+
|
129 |
+
return val_handlers
|
130 |
+
|
131 |
+
|
132 |
+
def get_train_handlers(
|
133 |
+
evaluator: monai.engines.SupervisedEvaluator,
|
134 |
+
config: dict
|
135 |
+
) -> list:
|
136 |
+
"""Create default handlers for model training
|
137 |
+
Args:
|
138 |
+
evaluator: an engine of type `monai.engines.SupervisedEvaluator` for evaluations
|
139 |
+
every epoch
|
140 |
+
|
141 |
+
Returns:
|
142 |
+
list of default handlers for training: [
|
143 |
+
ValidationHandler:
|
144 |
+
Allows model validation every epoch
|
145 |
+
StatsHandler:
|
146 |
+
???
|
147 |
+
TensorBoardStatsHandler:
|
148 |
+
Save loss from validation to `config.log_dir`, allow logging with TensorBoard
|
149 |
+
]
|
150 |
+
"""
|
151 |
+
|
152 |
+
train_handlers=[
|
153 |
+
ValidationHandler(
|
154 |
+
validator=evaluator,
|
155 |
+
interval=1,
|
156 |
+
epoch_level=True
|
157 |
+
),
|
158 |
+
StatsHandler(
|
159 |
+
tag_name="train_loss",
|
160 |
+
output_transform=from_engine(
|
161 |
+
["loss"],
|
162 |
+
first=True
|
163 |
+
)
|
164 |
+
),
|
165 |
+
StatsHandler(
|
166 |
+
tag_name='loss_logger',
|
167 |
+
iteration_print_logger=loss_logger
|
168 |
+
),
|
169 |
+
TensorBoardStatsHandler(
|
170 |
+
log_dir=config.log_dir,
|
171 |
+
tag_name="train_loss",
|
172 |
+
output_transform=from_engine(
|
173 |
+
["loss"],
|
174 |
+
first=True
|
175 |
+
),
|
176 |
+
)
|
177 |
+
]
|
178 |
+
|
179 |
+
return train_handlers
|
180 |
+
|
181 |
+
def get_evaluator(
|
182 |
+
config: dict,
|
183 |
+
device: torch.device ,
|
184 |
+
network: torch.nn.Module,
|
185 |
+
val_data_loader: monai.data.dataloader.DataLoader,
|
186 |
+
val_post_transforms: monai.transforms.compose.Compose,
|
187 |
+
val_handlers: Union[Callable, List]=get_val_handlers
|
188 |
+
) -> monai.engines.SupervisedEvaluator:
|
189 |
+
|
190 |
+
"""Create default evaluator for training of a segmentation model
|
191 |
+
Args:
|
192 |
+
device:
|
193 |
+
torch.cuda.device for model and engine
|
194 |
+
network:
|
195 |
+
nn.Module subclass, the model to train
|
196 |
+
val_data_loader:
|
197 |
+
Validation data loader, `monai.data.dataloader.DataLoader` subclass
|
198 |
+
val_post_transforms:
|
199 |
+
function to create transforms OR composed transforms
|
200 |
+
val_handlers:
|
201 |
+
function to create handerls OR List of handlers
|
202 |
+
|
203 |
+
Returns:
|
204 |
+
default evaluator for segmentation of type `monai.engines.SupervisedEvaluator`
|
205 |
+
"""
|
206 |
+
|
207 |
+
if callable(val_handlers): val_handlers=val_handlers()
|
208 |
+
|
209 |
+
evaluator=monai.engines.SupervisedEvaluator(
|
210 |
+
device=device,
|
211 |
+
val_data_loader=val_data_loader,
|
212 |
+
network=network,
|
213 |
+
inferer=monai.inferers.SlidingWindowInferer(
|
214 |
+
roi_size=(96, 96, 96),
|
215 |
+
sw_batch_size=4,
|
216 |
+
overlap=0.5
|
217 |
+
),
|
218 |
+
postprocessing=val_post_transforms,
|
219 |
+
key_val_metric={
|
220 |
+
"val_mean_dice": MeanDice(
|
221 |
+
include_background=False,
|
222 |
+
output_transform=from_engine(
|
223 |
+
["pred", "label"]
|
224 |
+
)
|
225 |
+
)
|
226 |
+
},
|
227 |
+
val_handlers=val_handlers,
|
228 |
+
# if no FP16 support in GPU or PyTorch version < 1.6, will not enable AMP evaluation
|
229 |
+
amp=USE_AMP,
|
230 |
+
)
|
231 |
+
evaluator.config=config
|
232 |
+
return evaluator
|
233 |
+
|
234 |
+
|
235 |
+
class SegmentationTrainer(monai.engines.SupervisedTrainer):
|
236 |
+
"Default Trainer für supervised segmentation task"
|
237 |
+
def __init__(self,
|
238 |
+
config: dict,
|
239 |
+
progress_bar: bool=True,
|
240 |
+
early_stopping: bool=True,
|
241 |
+
metrics: list=["MeanDice", "HausdorffDistance", "SurfaceDistance"],
|
242 |
+
save_latest_metrics: bool=True
|
243 |
+
):
|
244 |
+
self.config=config
|
245 |
+
self._prepare_dirs()
|
246 |
+
self.config.device=torch.device(self.config.device)
|
247 |
+
|
248 |
+
train_loader, val_loader=segmentation_dataloaders(
|
249 |
+
config=config,
|
250 |
+
train=True,
|
251 |
+
valid=True,
|
252 |
+
test=False
|
253 |
+
)
|
254 |
+
network=get_model(config=config).to(config.device)
|
255 |
+
optimizer=get_optimizer(
|
256 |
+
network,
|
257 |
+
config=config
|
258 |
+
)
|
259 |
+
loss_fn=get_loss(config=config)
|
260 |
+
val_post_transforms=get_val_post_transforms(config=config)
|
261 |
+
val_handlers=get_val_handlers(
|
262 |
+
network,
|
263 |
+
config=config
|
264 |
+
)
|
265 |
+
|
266 |
+
self.evaluator=get_evaluator(
|
267 |
+
config=config,
|
268 |
+
device=config.device,
|
269 |
+
network=network,
|
270 |
+
val_data_loader=val_loader,
|
271 |
+
val_post_transforms=val_post_transforms,
|
272 |
+
val_handlers=val_handlers,
|
273 |
+
|
274 |
+
)
|
275 |
+
train_handlers=get_train_handlers(
|
276 |
+
self.evaluator,
|
277 |
+
config=config
|
278 |
+
)
|
279 |
+
|
280 |
+
super().__init__(
|
281 |
+
device=config.device,
|
282 |
+
max_epochs=self.config.training.max_epochs,
|
283 |
+
train_data_loader=train_loader,
|
284 |
+
network=network,
|
285 |
+
optimizer=optimizer,
|
286 |
+
loss_function=loss_fn,
|
287 |
+
inferer=monai.inferers.SimpleInferer(),
|
288 |
+
train_handlers=train_handlers,
|
289 |
+
amp=USE_AMP,
|
290 |
+
)
|
291 |
+
|
292 |
+
if early_stopping: self._add_early_stopping()
|
293 |
+
if progress_bar: self._add_progress_bars()
|
294 |
+
|
295 |
+
self.schedulers=[]
|
296 |
+
# add different metrics dynamically
|
297 |
+
for m in metrics:
|
298 |
+
getattr(monai.handlers, m)(
|
299 |
+
include_background=False,
|
300 |
+
reduction="mean",
|
301 |
+
output_transform=from_engine(
|
302 |
+
["pred", "label"]
|
303 |
+
)
|
304 |
+
).attach(self.evaluator, m)
|
305 |
+
|
306 |
+
self._add_metrics_logger()
|
307 |
+
# add eval loss to metrics
|
308 |
+
self._add_eval_loss()
|
309 |
+
|
310 |
+
if save_latest_metrics: self._add_metrics_saver()
|
311 |
+
|
312 |
+
|
313 |
+
def _prepare_dirs(self)->None:
|
314 |
+
# create run_id, copy config file for reproducibility
|
315 |
+
os.makedirs(self.config.run_id, exist_ok=True)
|
316 |
+
with open(
|
317 |
+
os.path.join(
|
318 |
+
self.config.run_id,
|
319 |
+
'config.yaml'
|
320 |
+
), 'w+') as f:
|
321 |
+
f.write(yaml.safe_dump(self.config))
|
322 |
+
|
323 |
+
# delete old log_dir
|
324 |
+
if os.path.exists(self.config.log_dir):
|
325 |
+
shutil.rmtree(self.config.log_dir)
|
326 |
+
|
327 |
+
def _add_early_stopping(self) -> None:
|
328 |
+
early_stopping=EarlyStopHandler(
|
329 |
+
patience=self.config.training.early_stopping_patience,
|
330 |
+
min_delta=1e-4,
|
331 |
+
score_function=lambda x: x.state.metrics[x.state.key_metric_name],
|
332 |
+
trainer=self
|
333 |
+
)
|
334 |
+
self.evaluator.add_event_handler(
|
335 |
+
ignite.engine.Events.COMPLETED,
|
336 |
+
early_stopping
|
337 |
+
)
|
338 |
+
|
339 |
+
def _add_metrics_logger(self) -> None:
|
340 |
+
self.metric_logger=MetricLogger(
|
341 |
+
evaluator=self.evaluator
|
342 |
+
)
|
343 |
+
self.metric_logger.attach(self)
|
344 |
+
|
345 |
+
def _add_progress_bars(self) -> None:
|
346 |
+
trainer_pbar=ProgressBar()
|
347 |
+
evaluator_pbar=ProgressBar(
|
348 |
+
colour='green'
|
349 |
+
)
|
350 |
+
trainer_pbar.attach(
|
351 |
+
self,
|
352 |
+
output_transform=lambda output:{
|
353 |
+
'loss': torch.tensor(
|
354 |
+
[x['loss'] for x in output]
|
355 |
+
).mean()
|
356 |
+
}
|
357 |
+
)
|
358 |
+
evaluator_pbar.attach(self.evaluator)
|
359 |
+
|
360 |
+
def _add_metrics_saver(self) -> None:
|
361 |
+
metric_saver=MetricsSaver(
|
362 |
+
save_dir=self.config.out_dir,
|
363 |
+
metric_details='*',
|
364 |
+
batch_transform=self._get_meta_dict,
|
365 |
+
delimiter=','
|
366 |
+
)
|
367 |
+
metric_saver.attach(self.evaluator)
|
368 |
+
|
369 |
+
def _add_eval_loss(self)->None:
|
370 |
+
# TODO improve by adding this to val handlers
|
371 |
+
eval_loss_handler=ignite.metrics.Loss(
|
372 |
+
loss_fn=self.loss_function,
|
373 |
+
output_transform=lambda output: (
|
374 |
+
output[0]['pred'].unsqueeze(0), # add batch dim
|
375 |
+
output[0]['label'].argmax(0, keepdim=True).unsqueeze(0) # reverse one-hot, add batch dim
|
376 |
+
)
|
377 |
+
)
|
378 |
+
eval_loss_handler.attach(self.evaluator, 'eval_loss')
|
379 |
+
|
380 |
+
def _get_meta_dict(self, batch) -> list:
|
381 |
+
"Get dict of metadata from engine. Needed as `batch_transform`"
|
382 |
+
image_cols=self.config.data.image_cols
|
383 |
+
image_name=image_cols[0] if isinstance(image_cols, list) else image_cols
|
384 |
+
key=f'{image_name}_meta_dict'
|
385 |
+
return [item[key] for item in batch]
|
386 |
+
|
387 |
+
def load_checkpoint(self, checkpoint=None):
|
388 |
+
if not checkpoint:
|
389 |
+
# get name of last checkpoint
|
390 |
+
checkpoint = os.path.join(
|
391 |
+
self.config.model_dir,
|
392 |
+
f"network_{self.config.run_id}_key_metric={self.evaluator.state.best_metric:.4f}.pt"
|
393 |
+
)
|
394 |
+
self.network.load_state_dict(
|
395 |
+
torch.load(checkpoint)
|
396 |
+
)
|
397 |
+
|
398 |
+
def run(self, try_resume_from_checkpoint=True) -> None:
|
399 |
+
"""Run training, if `try_resume_from_checkpoint` tries to
|
400 |
+
load previous checkpoint stored at `self.config.model_dir`
|
401 |
+
"""
|
402 |
+
|
403 |
+
if try_resume_from_checkpoint:
|
404 |
+
checkpoints = [
|
405 |
+
os.path.join(
|
406 |
+
self.config.model_dir,
|
407 |
+
checkpoint_name
|
408 |
+
) for checkpoint_name in os.listdir(
|
409 |
+
self.config.model_dir
|
410 |
+
) if self.config.run_id in checkpoint_name
|
411 |
+
]
|
412 |
+
try:
|
413 |
+
checkpoint = sorted(checkpoints)[-1]
|
414 |
+
self.load_checkpoint(checkpoint)
|
415 |
+
print(f"resuming from previous checkpoint at {checkpoint}")
|
416 |
+
except: pass # train from scratch
|
417 |
+
|
418 |
+
# train the model
|
419 |
+
super().run()
|
420 |
+
|
421 |
+
# make metrics and losses more accessible
|
422 |
+
self.loss={
|
423 |
+
"iter": [_iter for _iter, _ in self.metric_logger.loss],
|
424 |
+
"loss": [_loss for _, _loss in self.metric_logger.loss],
|
425 |
+
"epoch": [_iter // self.state.epoch_length for _iter, _ in self.metric_logger.loss]
|
426 |
+
}
|
427 |
+
|
428 |
+
self.metrics={
|
429 |
+
k: [item[1] for item in self.metric_logger.metrics[k]] for k in
|
430 |
+
self.evaluator.state.metric_details.keys()
|
431 |
+
}
|
432 |
+
# pd.DataFrame(self.metrics).to_csv(f"{self.config.out_dir}/metric_logs.csv")
|
433 |
+
# pd.DataFrame(self.loss).to_csv(f"{self.config.out_dir}/loss_logs.csv")
|
434 |
+
|
435 |
+
def fit_one_cycle(self, try_resume_from_checkpoint=True) -> None:
|
436 |
+
"Run training using one-cycle-policy"
|
437 |
+
assert "FitOneCycle" not in self.schedulers, "FitOneCycle already added"
|
438 |
+
fit_one_cycle=monai.handlers.LrScheduleHandler(
|
439 |
+
torch.optim.lr_scheduler.OneCycleLR(
|
440 |
+
optimizer=self.optimizer,
|
441 |
+
max_lr=self.optimizer.param_groups[0]['lr'],
|
442 |
+
steps_per_epoch=self.state.epoch_length,
|
443 |
+
epochs=self.state.max_epochs
|
444 |
+
),
|
445 |
+
epoch_level=False,
|
446 |
+
name="FitOneCycle"
|
447 |
+
)
|
448 |
+
fit_one_cycle.attach(self)
|
449 |
+
self.schedulers += ["FitOneCycle"]
|
450 |
+
|
451 |
+
def reduce_lr_on_plateau(self,
|
452 |
+
try_resume_from_checkpoint=True,
|
453 |
+
factor=0.1,
|
454 |
+
patience=10,
|
455 |
+
min_lr=1e-10,
|
456 |
+
verbose=True) -> None:
|
457 |
+
"Reduce learning rate by `factor` every `patience` epochs if kex_metric does not improve"
|
458 |
+
assert "ReduceLROnPlateau" not in self.schedulers, "ReduceLROnPlateau already added"
|
459 |
+
reduce_lr_on_plateau=monai.handlers.LrScheduleHandler(
|
460 |
+
torch.optim.lr_scheduler.ReduceLROnPlateau(
|
461 |
+
optimizer=self.optimizer,
|
462 |
+
factor=factor,
|
463 |
+
patience=patience,
|
464 |
+
min_lr=min_lr,
|
465 |
+
verbose=verbose
|
466 |
+
),
|
467 |
+
print_lr=True,
|
468 |
+
name='ReduceLROnPlateau',
|
469 |
+
epoch_level=True,
|
470 |
+
step_transform=lambda engine: engine.state.metrics[engine.state.key_metric_name],
|
471 |
+
)
|
472 |
+
reduce_lr_on_plateau.attach(self.evaluator)
|
473 |
+
self.schedulers += ["ReduceLROnPlateau"]
|
474 |
+
|
475 |
+
def evaluate(self, checkpoint=None, dataloader=None):
|
476 |
+
"Run evaluation with best saved checkpoint"
|
477 |
+
self.load_checkpoint(checkpoint)
|
478 |
+
if dataloader:
|
479 |
+
self.evaluator.set_data(dataloader)
|
480 |
+
self.evaluator.state.epoch_length=len(dataloader)
|
481 |
+
self.evaluator.run()
|
482 |
+
print(f"metrics saved to {self.config.out_dir}")
|
prostate/transforms.py
ADDED
@@ -0,0 +1,363 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# create transforms for training, validation and test dataset
|
2 |
+
|
3 |
+
## TODO: Make Transforms more dynamic by directly building from config args
|
4 |
+
## Maybe like this
|
5 |
+
## TFM_NAME=config.transforms.keys()[0]
|
6 |
+
## tfm_fun=getattr(monai.transforms, TFM_NAME)
|
7 |
+
## tmfs+=[tfms_fun(keys=image+cols, **config.transforms[TFM_NAME], prob=prob, mode=mode)
|
8 |
+
|
9 |
+
|
10 |
+
## ---------- imports ----------
|
11 |
+
import os
|
12 |
+
# only import of base transforms, others are imported as needed
|
13 |
+
from monai.utils.enums import CommonKeys
|
14 |
+
from monai.transforms import (
|
15 |
+
Activationsd,
|
16 |
+
AsDiscreted,
|
17 |
+
Compose,
|
18 |
+
ConcatItemsd,
|
19 |
+
KeepLargestConnectedComponentd,
|
20 |
+
LoadImaged,
|
21 |
+
EnsureChannelFirstd,
|
22 |
+
EnsureTyped,
|
23 |
+
SaveImaged,
|
24 |
+
ScaleIntensityd,
|
25 |
+
NormalizeIntensityd
|
26 |
+
)
|
27 |
+
# images should be interploated with `bilinear` but masks with `nearest`
|
28 |
+
|
29 |
+
## ---------- base transforms ----------
|
30 |
+
# applied everytime
|
31 |
+
def get_base_transforms(
|
32 |
+
config: dict,
|
33 |
+
minv: int=0,
|
34 |
+
maxv: int=1
|
35 |
+
)->list:
|
36 |
+
|
37 |
+
tfms=[]
|
38 |
+
tfms+=[LoadImaged(keys=config.data.image_cols+config.data.label_cols)]
|
39 |
+
tfms+=[EnsureChannelFirstd(keys=config.data.image_cols+config.data.label_cols)]
|
40 |
+
if config.transforms.spacing:
|
41 |
+
from monai.transforms import Spacingd
|
42 |
+
tfms+=[
|
43 |
+
Spacingd(
|
44 |
+
keys=config.data.image_cols+config.data.label_cols,
|
45 |
+
pixdim=config.transforms.spacing,
|
46 |
+
mode=config.transforms.mode
|
47 |
+
)
|
48 |
+
]
|
49 |
+
if config.transforms.orientation:
|
50 |
+
from monai.transforms import Orientationd
|
51 |
+
tfms+=[
|
52 |
+
Orientationd(
|
53 |
+
keys=config.data.image_cols+config.data.label_cols,
|
54 |
+
axcodes=config.transforms.orientation
|
55 |
+
)
|
56 |
+
]
|
57 |
+
tfms+=[
|
58 |
+
ScaleIntensityd(
|
59 |
+
keys=config.data.image_cols,
|
60 |
+
minv=minv,
|
61 |
+
maxv=maxv
|
62 |
+
)
|
63 |
+
]
|
64 |
+
tfms+=[NormalizeIntensityd(keys=config.data.image_cols)]
|
65 |
+
return tfms
|
66 |
+
|
67 |
+
## ---------- train transforms ----------
|
68 |
+
|
69 |
+
def get_train_transforms(config: dict):
|
70 |
+
tfms=get_base_transforms(config=config)
|
71 |
+
|
72 |
+
# ---------- specific transforms for mri ----------
|
73 |
+
if 'rand_bias_field' in config.transforms.keys():
|
74 |
+
from monai.transforms import RandBiasFieldd
|
75 |
+
args=config.transforms.rand_bias_field
|
76 |
+
tfms+=[
|
77 |
+
RandBiasFieldd(
|
78 |
+
keys=config.data.image_cols,
|
79 |
+
degree=args['degree'],
|
80 |
+
coeff_range=args['coeff_range'],
|
81 |
+
prob=config.transforms.prob
|
82 |
+
)
|
83 |
+
]
|
84 |
+
|
85 |
+
if 'rand_gaussian_smooth' in config.transforms.keys():
|
86 |
+
from monai.transforms import RandGaussianSmoothd
|
87 |
+
args=config.transforms.rand_gaussian_smooth
|
88 |
+
tfms+=[
|
89 |
+
RandGaussianSmoothd(
|
90 |
+
keys=config.data.image_cols,
|
91 |
+
sigma_x=args['sigma_x'],
|
92 |
+
sigma_y=args['sigma_y'],
|
93 |
+
sigma_z=args['sigma_z'],
|
94 |
+
prob=config.transforms.prob
|
95 |
+
)
|
96 |
+
]
|
97 |
+
|
98 |
+
if 'rand_gibbs_nose' in config.transforms.keys():
|
99 |
+
from monai.transforms import RandGibbsNoised
|
100 |
+
args=config.transforms.rand_gibbs_nose
|
101 |
+
tfms+=[
|
102 |
+
RandGibbsNoised(
|
103 |
+
keys=config.data.image_cols,
|
104 |
+
alpha=args['alpha'],
|
105 |
+
prob=config.transforms.prob
|
106 |
+
)
|
107 |
+
]
|
108 |
+
|
109 |
+
# ---------- affine transforms ----------
|
110 |
+
|
111 |
+
if 'rand_affine' in config.transforms.keys():
|
112 |
+
from monai.transforms import RandAffined
|
113 |
+
args=config.transforms.rand_affine
|
114 |
+
tfms+=[
|
115 |
+
RandAffined(
|
116 |
+
keys=config.data.image_cols+config.data.label_cols,
|
117 |
+
rotate_range=args['rotate_range'],
|
118 |
+
shear_range=args['shear_range'],
|
119 |
+
translate_range=args['translate_range'],
|
120 |
+
mode=config.transforms.mode,
|
121 |
+
prob=config.transforms.prob
|
122 |
+
)
|
123 |
+
]
|
124 |
+
|
125 |
+
if 'rand_rotate90' in config.transforms.keys():
|
126 |
+
from monai.transforms import RandRotate90d
|
127 |
+
args=config.transforms.rand_rotate90
|
128 |
+
tfms+=[
|
129 |
+
RandRotate90d(
|
130 |
+
keys=config.data.image_cols+config.data.label_cols,
|
131 |
+
spatial_axes=args['spatial_axes'],
|
132 |
+
prob=config.transforms.prob
|
133 |
+
)
|
134 |
+
]
|
135 |
+
|
136 |
+
if 'rand_rotate' in config.transforms.keys():
|
137 |
+
from monai.transforms import RandRotated
|
138 |
+
args=config.transforms.rand_rotate
|
139 |
+
tfms+=[
|
140 |
+
RandRotated(
|
141 |
+
keys=config.data.image_cols+config.data.label_cols,
|
142 |
+
range_x=args['range_x'],
|
143 |
+
range_y=args['range_y'],
|
144 |
+
range_z=args['range_z'],
|
145 |
+
mode=config.transforms.mode,
|
146 |
+
prob=config.transforms.prob
|
147 |
+
)
|
148 |
+
]
|
149 |
+
|
150 |
+
if 'rand_elastic' in config.transforms.keys():
|
151 |
+
if config['ndim'] == 3:
|
152 |
+
from monai.transforms import Rand3DElasticd as RandElasticd
|
153 |
+
elif config['ndim'] == 2:
|
154 |
+
from monai.transforms import Rand2DElasticd as RandElasticd
|
155 |
+
args=config.transforms.rand_elastic
|
156 |
+
tfms+=[
|
157 |
+
RandElasticd(
|
158 |
+
keys=config.data.image_cols+config.data.label_cols,
|
159 |
+
sigma_range=args['sigma_range'],
|
160 |
+
magnitude_range=args['magnitude_range'],
|
161 |
+
rotate_range=args['rotate_range'],
|
162 |
+
shear_range=args['shear_range'],
|
163 |
+
translate_range=args['translate_range'],
|
164 |
+
mode=config.transforms.mode,
|
165 |
+
prob=config.transforms.prob
|
166 |
+
)
|
167 |
+
]
|
168 |
+
|
169 |
+
if 'rand_zoom' in config.transforms.keys():
|
170 |
+
from monai.transforms import RandZoomd
|
171 |
+
args=config.transforms.rand_zoom
|
172 |
+
tfms+=[
|
173 |
+
RandZoomd(
|
174 |
+
keys=config.data.image_cols+config.data.label_cols,
|
175 |
+
min_zoom=args['min'],
|
176 |
+
max_zoom=args['max'],
|
177 |
+
mode=['area' if x == 'bilinear' else x for x in config.transforms.mode],
|
178 |
+
prob=config.transforms.prob
|
179 |
+
)
|
180 |
+
]
|
181 |
+
|
182 |
+
# ---------- random cropping, very effective for large images ----------
|
183 |
+
# RandCropByPosNegLabeld is not advisable for data with missing lables
|
184 |
+
# e.g., segmentation of carcinomas which are not present on all images
|
185 |
+
# thus fallback to RandSpatialCropSamplesd. Completly replacing Cropping
|
186 |
+
# by just resizing could be discussed, but I believe it is not beneficial
|
187 |
+
# For the first version, this is an ungly hack. For the second version,
|
188 |
+
# a better verion for transforms should be written.
|
189 |
+
|
190 |
+
if 'rand_crop_pos_neg_label' in config.transforms.keys():
|
191 |
+
from monai.transforms import RandCropByPosNegLabeld
|
192 |
+
args=config.transforms.rand_crop_pos_neg_label
|
193 |
+
tfms+=[
|
194 |
+
RandCropByPosNegLabeld(
|
195 |
+
keys=config.data.image_cols+config.data.label_cols,
|
196 |
+
label_key=config.data.label_cols[0],
|
197 |
+
spatial_size=args['spatial_size'],
|
198 |
+
pos=args['pos'],
|
199 |
+
neg=args['neg'],
|
200 |
+
num_samples=args['num_samples'],
|
201 |
+
image_key=config.data.image_cols[0],
|
202 |
+
image_threshold=0,
|
203 |
+
)
|
204 |
+
]
|
205 |
+
|
206 |
+
elif 'rand_spatial_crop_samples' in config.transforms.keys():
|
207 |
+
from monai.transforms import RandSpatialCropSamplesd
|
208 |
+
args=config.transforms.rand_spatial_crop_samples
|
209 |
+
tfms+=[
|
210 |
+
RandSpatialCropSamplesd(
|
211 |
+
keys=config.data.image_cols+config.data.label_cols,
|
212 |
+
roi_size=args['roi_size'],
|
213 |
+
random_size=False,
|
214 |
+
num_samples=args['num_samples'],
|
215 |
+
)
|
216 |
+
]
|
217 |
+
|
218 |
+
else:
|
219 |
+
raise ValueError('Either `rand_crop_pos_neg_label` or `rand_spatial_crop_samples` '\
|
220 |
+
'need to be specified')
|
221 |
+
|
222 |
+
# ---------- intensity transforms ----------
|
223 |
+
|
224 |
+
if 'gaussian_noise' in config.transforms.keys():
|
225 |
+
from monai.transforms import RandGaussianNoised
|
226 |
+
args=config.transforms.gaussian_noise
|
227 |
+
tfms+=[
|
228 |
+
RandGaussianNoised(
|
229 |
+
keys=config.data.image_cols,
|
230 |
+
mean=args['mean'],
|
231 |
+
std=args['std'],
|
232 |
+
prob=config.transforms.prob
|
233 |
+
)
|
234 |
+
]
|
235 |
+
|
236 |
+
if 'shift_intensity' in config.transforms.keys():
|
237 |
+
from monai.transforms import RandShiftIntensityd
|
238 |
+
args=config.transforms.shift_intensity
|
239 |
+
tfms+=[
|
240 |
+
RandShiftIntensityd(
|
241 |
+
keys=config.data.image_cols,
|
242 |
+
offsets=args['offsets'],
|
243 |
+
prob=config.transforms.prob
|
244 |
+
)
|
245 |
+
]
|
246 |
+
|
247 |
+
if 'gaussian_sharpen' in config.transforms.keys():
|
248 |
+
from monai.transforms import RandGaussianSharpend
|
249 |
+
args=config.transforms.gaussian_sharpen
|
250 |
+
tfms+=[
|
251 |
+
RandGaussianSharpend(
|
252 |
+
keys=config.data.image_cols,
|
253 |
+
sigma1_x=args['sigma1_x'],
|
254 |
+
sigma1_y=args['sigma1_y'],
|
255 |
+
sigma1_z=args['sigma1_z'],
|
256 |
+
sigma2_x=args['sigma2_x'],
|
257 |
+
sigma2_y=args['sigma2_y'],
|
258 |
+
sigma2_z=args['sigma2_z'],
|
259 |
+
alpha=args['alpha'],
|
260 |
+
prob=config.transforms.prob
|
261 |
+
)
|
262 |
+
]
|
263 |
+
|
264 |
+
if 'adjust_contrast' in config.transforms.keys():
|
265 |
+
from monai.transforms import RandAdjustContrastd
|
266 |
+
args=config.transforms.adjust_contrast
|
267 |
+
tfms+=[
|
268 |
+
RandAdjustContrastd(
|
269 |
+
keys=config.data.image_cols,
|
270 |
+
gamma=args['gamma'],
|
271 |
+
prob=config.transforms.prob
|
272 |
+
)
|
273 |
+
]
|
274 |
+
|
275 |
+
# Concat mutlisequence data to single Tensors on the ChannelDim
|
276 |
+
# Rename images to `CommonKeys.IMAGE` and labels to `CommonKeys.LABELS`
|
277 |
+
# for more compatibility with monai.engines
|
278 |
+
|
279 |
+
tfms+=[
|
280 |
+
ConcatItemsd(
|
281 |
+
keys=config.data.image_cols,
|
282 |
+
name=CommonKeys.IMAGE,
|
283 |
+
dim=0
|
284 |
+
)
|
285 |
+
]
|
286 |
+
|
287 |
+
tfms+=[
|
288 |
+
ConcatItemsd(
|
289 |
+
keys=config.data.label_cols,
|
290 |
+
name=CommonKeys.LABEL,
|
291 |
+
dim=0
|
292 |
+
)
|
293 |
+
]
|
294 |
+
|
295 |
+
return Compose(tfms)
|
296 |
+
|
297 |
+
## ---------- valid transforms ----------
|
298 |
+
|
299 |
+
def get_val_transforms(config: dict):
|
300 |
+
tfms=get_base_transforms(config=config)
|
301 |
+
tfms+=[EnsureTyped(keys=config.data.image_cols+config.data.label_cols)]
|
302 |
+
tfms+=[
|
303 |
+
ConcatItemsd(
|
304 |
+
keys=config.data.image_cols,
|
305 |
+
name=CommonKeys.IMAGE,
|
306 |
+
dim=0
|
307 |
+
)
|
308 |
+
]
|
309 |
+
|
310 |
+
tfms+=[
|
311 |
+
ConcatItemsd(
|
312 |
+
keys=config.data.label_cols,
|
313 |
+
name=CommonKeys.LABEL,
|
314 |
+
dim=0
|
315 |
+
)
|
316 |
+
]
|
317 |
+
|
318 |
+
return Compose(tfms)
|
319 |
+
|
320 |
+
## ---------- test transforms ----------
|
321 |
+
# same as valid transforms
|
322 |
+
|
323 |
+
def get_test_transforms(config: dict):
|
324 |
+
tfms=get_base_transforms(config=config)
|
325 |
+
tfms+=[EnsureTyped(keys=config.data.image_cols+config.data.label_cols)]
|
326 |
+
tfms+=[
|
327 |
+
ConcatItemsd(
|
328 |
+
keys=config.data.image_cols,
|
329 |
+
name=CommonKeys.IMAGE,
|
330 |
+
dim=0
|
331 |
+
)
|
332 |
+
]
|
333 |
+
|
334 |
+
tfms+=[
|
335 |
+
ConcatItemsd(
|
336 |
+
keys=config.data.label_cols,
|
337 |
+
name=CommonKeys.LABEL,
|
338 |
+
dim=0
|
339 |
+
)
|
340 |
+
]
|
341 |
+
|
342 |
+
return Compose(tfms)
|
343 |
+
|
344 |
+
|
345 |
+
def get_val_post_transforms(config: dict):
|
346 |
+
tfms=[EnsureTyped(keys=[CommonKeys.PRED, CommonKeys.LABEL]),
|
347 |
+
AsDiscreted(
|
348 |
+
keys=CommonKeys.PRED,
|
349 |
+
argmax=True,
|
350 |
+
to_onehot=config.model.out_channels,
|
351 |
+
num_classes=config.model.out_channels
|
352 |
+
),
|
353 |
+
AsDiscreted(
|
354 |
+
keys=CommonKeys.LABEL,
|
355 |
+
to_onehot=config.model.out_channels,
|
356 |
+
num_classes=config.model.out_channels
|
357 |
+
),
|
358 |
+
KeepLargestConnectedComponentd(
|
359 |
+
keys=CommonKeys.PRED,
|
360 |
+
applied_labels=list(range(1, config.model.out_channels))
|
361 |
+
),
|
362 |
+
]
|
363 |
+
return Compose(tfms)
|
prostate/utils.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import yaml
|
3 |
+
import torch
|
4 |
+
import monai
|
5 |
+
import munch
|
6 |
+
|
7 |
+
def load_config(fn: str='config.yaml'):
|
8 |
+
"Load config from YAML and return a serialized dictionary object"
|
9 |
+
with open(fn, 'r') as stream:
|
10 |
+
config=yaml.safe_load(stream)
|
11 |
+
config=munch.munchify(config)
|
12 |
+
|
13 |
+
if not config.overwrite:
|
14 |
+
i=1
|
15 |
+
while os.path.exists(config.run_id+f'_{i}'):
|
16 |
+
i+=1
|
17 |
+
config.run_id+=f'_{i}'
|
18 |
+
|
19 |
+
config.out_dir = os.path.join(config.run_id, config.out_dir)
|
20 |
+
config.log_dir = os.path.join(config.run_id, config.log_dir)
|
21 |
+
|
22 |
+
if not isinstance(config.data.image_cols, (tuple, list)):
|
23 |
+
config.data.image_cols=[config.data.image_cols]
|
24 |
+
if not isinstance(config.data.label_cols, (tuple, list)):
|
25 |
+
config.data.label_cols=[config.data.label_cols]
|
26 |
+
|
27 |
+
config.transforms.mode=('bilinear', ) * len(config.data.image_cols) + \
|
28 |
+
('nearest', ) * len(config.data.label_cols)
|
29 |
+
return config
|
30 |
+
|
31 |
+
|
32 |
+
def num_workers():
|
33 |
+
"Get max supported workers -2 for multiprocessing"
|
34 |
+
import resource
|
35 |
+
import multiprocessing
|
36 |
+
|
37 |
+
# first check for max number of open files allowed on system
|
38 |
+
soft_limit, hard_limit = resource.getrlimit(resource.RLIMIT_NOFILE)
|
39 |
+
n_workers=multiprocessing.cpu_count() - 2
|
40 |
+
# giving each worker at least 256 open processes should allow them to run smoothly
|
41 |
+
max_workers = soft_limit // 256
|
42 |
+
if max_workers < n_workers:
|
43 |
+
print(
|
44 |
+
"Will not use all available workers as number of allowed open files is to small"
|
45 |
+
"to ensure smooth multiprocessing. Current limits are:\n"
|
46 |
+
f"\t soft_limit: {soft_limit}\n"
|
47 |
+
f"\t hard_limit: {hard_limit}\n"
|
48 |
+
"try increasing the limits to at least {256*n_workers}."
|
49 |
+
"See https://superuser.com/questions/1200539/cannot-increase-open-file-limit-past-4096-ubuntu"
|
50 |
+
"for more details"
|
51 |
+
)
|
52 |
+
return max_workers
|
53 |
+
|
54 |
+
return n_workers
|
55 |
+
|
56 |
+
USE_AMP=True if monai.utils.get_torch_version_tuple() >= (1, 6) else False
|
prostate/viewer.py
ADDED
@@ -0,0 +1,352 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import ipywidgets
|
3 |
+
import numpy as np
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
from IPython.display import display
|
6 |
+
from itertools import chain, islice
|
7 |
+
from ipywidgets import interactive, widgets
|
8 |
+
|
9 |
+
def _create_label(text:str)->ipywidgets.widgets.Label:
|
10 |
+
"Create label widget"
|
11 |
+
|
12 |
+
label = widgets.Label(
|
13 |
+
text,
|
14 |
+
layout=widgets.Layout(
|
15 |
+
width='100%',
|
16 |
+
display='flex',
|
17 |
+
justify_content="center"
|
18 |
+
)
|
19 |
+
)
|
20 |
+
return label
|
21 |
+
|
22 |
+
def _create_slider(
|
23 |
+
slider_min: int,
|
24 |
+
slider_max: int,
|
25 |
+
value: int,
|
26 |
+
step: int=1,
|
27 |
+
description:str ='',
|
28 |
+
continuous_update: bool=True,
|
29 |
+
readout: bool=False,
|
30 |
+
slider_type: str='IntSlider',
|
31 |
+
**kwargs)->ipywidgets.widgets:
|
32 |
+
"Create slider widget"
|
33 |
+
|
34 |
+
slider = getattr(widgets, slider_type)(
|
35 |
+
min=slider_min,
|
36 |
+
max=slider_max,
|
37 |
+
step=step,
|
38 |
+
value=value,
|
39 |
+
description=description,
|
40 |
+
continuous_update=continuous_update,
|
41 |
+
readout = readout,
|
42 |
+
layout=widgets.Layout(width='99%', min_width='200px'),
|
43 |
+
style={'description_width': 'initial'},
|
44 |
+
**kwargs
|
45 |
+
)
|
46 |
+
return slider
|
47 |
+
|
48 |
+
def _create_button(description:str)->ipywidgets.widgets.Button:
|
49 |
+
"Create button widget"
|
50 |
+
button = widgets.Button(
|
51 |
+
description=description,
|
52 |
+
layout=widgets.Layout(
|
53 |
+
width='95%',
|
54 |
+
margin='5px 5px'
|
55 |
+
)
|
56 |
+
)
|
57 |
+
return button
|
58 |
+
|
59 |
+
def _create_togglebutton(description: str,
|
60 |
+
value: int,
|
61 |
+
**kwargs)->ipywidgets.widgets.Button:
|
62 |
+
"Create toggle button widget"
|
63 |
+
button = widgets.ToggleButton(
|
64 |
+
description=description,
|
65 |
+
value = value,
|
66 |
+
layout=widgets.Layout(
|
67 |
+
width='95%',
|
68 |
+
margin='5px 5px'
|
69 |
+
), **kwargs
|
70 |
+
)
|
71 |
+
return button
|
72 |
+
|
73 |
+
|
74 |
+
class BasicViewer():
|
75 |
+
""" Base class for viewing TensorDicom3D objects.
|
76 |
+
|
77 |
+
Args:
|
78 |
+
x: main image object to view as rank 3 tensor
|
79 |
+
y: either a segmentation mask as as rank 3 tensor or a label as str.
|
80 |
+
prediction: a class predicton as str
|
81 |
+
description: description of the whole image
|
82 |
+
figsize: size of image, passed as plotting argument
|
83 |
+
cmap: colormap for the image
|
84 |
+
Returns:
|
85 |
+
Instance of BasicViewer
|
86 |
+
"""
|
87 |
+
|
88 |
+
def __init__(self, x:torch.Tensor, y=None, prediction:str=None, description: str=None,
|
89 |
+
figsize=(3, 3), cmap:str='bone'):
|
90 |
+
assert x.ndim == 3, f"x.ndim needs to be equal to but is {x.ndim}"
|
91 |
+
if isinstance(y, torch.Tensor):
|
92 |
+
assert x.shape == y.shape, f"Shapes of x {x.shape} and y {y.shape} do not match"
|
93 |
+
self.x=x
|
94 |
+
self.y=y
|
95 |
+
self.prediction=prediction
|
96 |
+
self.description=description
|
97 |
+
self.figsize=figsize
|
98 |
+
self.cmap=cmap
|
99 |
+
self.with_mask = isinstance(y, torch.Tensor)
|
100 |
+
self.slice_range = (1, len(x)) # len(x) == im.shape[0]
|
101 |
+
|
102 |
+
def _plot_slice(self, im_slice, with_mask, px_range):
|
103 |
+
"Plot slice of image"
|
104 |
+
fig, ax = plt.subplots(1, 1, figsize=self.figsize)
|
105 |
+
ax.imshow(self.x[im_slice-1, :, :].clip(*px_range), cmap=self.cmap)
|
106 |
+
if isinstance(self.y, (torch.Tensor)) and with_mask:
|
107 |
+
ax.imshow(self.y[im_slice-1, :, :], cmap='jet', alpha = 0.25)
|
108 |
+
plt.axis('off')
|
109 |
+
ax.set_xticks([])
|
110 |
+
ax.set_yticks([])
|
111 |
+
plt.show()
|
112 |
+
|
113 |
+
def _create_image_box(self, figsize):
|
114 |
+
"Create widget items, order them in item_box and generate view box"
|
115 |
+
items = []
|
116 |
+
|
117 |
+
if self.description: plot_description = _create_label(self.description)
|
118 |
+
|
119 |
+
if isinstance(self.y, str):
|
120 |
+
label = f'{self.y} | {self.prediction}' if self.prediction else self.y
|
121 |
+
if self.prediction:
|
122 |
+
font_color = 'green' if self.y == self.prediction else 'red'
|
123 |
+
y_label = _create_label(r'\(\color{' + font_color + '} {' + label + '}\)')
|
124 |
+
else:
|
125 |
+
y_label = _create_label(label)
|
126 |
+
else: y_label = _create_label(' ')
|
127 |
+
|
128 |
+
slice_slider = _create_slider(
|
129 |
+
slider_min = min(self.slice_range),
|
130 |
+
slider_max = max(self.slice_range),
|
131 |
+
value = max(self.slice_range)//2,
|
132 |
+
readout = True)
|
133 |
+
|
134 |
+
toggle_mask_button = _create_togglebutton('Show Mask', True)
|
135 |
+
|
136 |
+
range_slider = _create_slider(
|
137 |
+
slider_min = self.x.min().numpy(),
|
138 |
+
slider_max = self.x.max().numpy(),
|
139 |
+
value = [self.x.min().numpy(), self.x.max().numpy()],
|
140 |
+
slider_type = 'FloatRangeSlider' if torch.is_floating_point(self.x) else 'IntRandSlider',
|
141 |
+
step = 0.01 if torch.is_floating_point(self.x) else 1,
|
142 |
+
readout=True)
|
143 |
+
|
144 |
+
image_output = widgets.interactive_output(
|
145 |
+
f = self._plot_slice,
|
146 |
+
controls = {'im_slice': slice_slider,
|
147 |
+
'with_mask': toggle_mask_button,
|
148 |
+
'px_range': range_slider})
|
149 |
+
|
150 |
+
image_output.layout.height = f'{self.figsize[0]/1.2}in' # suppress flickering
|
151 |
+
image_output.layout.width = f'{self.figsize[1]/1.2}in' # suppress flickering
|
152 |
+
|
153 |
+
if self.description: items.append(plot_description)
|
154 |
+
items.append(y_label)
|
155 |
+
items.append(range_slider)
|
156 |
+
items.append(image_output)
|
157 |
+
if isinstance(self.y, torch.Tensor):
|
158 |
+
slice_slider = widgets.HBox([slice_slider, toggle_mask_button])
|
159 |
+
items.append(slice_slider)
|
160 |
+
|
161 |
+
image_box=widgets.VBox(
|
162 |
+
items,
|
163 |
+
layout = widgets.Layout(
|
164 |
+
border = 'none',
|
165 |
+
margin = '10px 5px 0px 0px',
|
166 |
+
padding = '5px'))
|
167 |
+
|
168 |
+
return image_box
|
169 |
+
|
170 |
+
def _generate_views(self):
|
171 |
+
image_box = self._create_image_box(self.figsize)
|
172 |
+
self.box = widgets.HBox(children=[image_box])
|
173 |
+
|
174 |
+
@property
|
175 |
+
def image_box(self):
|
176 |
+
return self._create_image_box(self.figsize)
|
177 |
+
|
178 |
+
def show(self):
|
179 |
+
self._generate_views()
|
180 |
+
plt.style.use('default')
|
181 |
+
display(self.box)
|
182 |
+
|
183 |
+
|
184 |
+
class DicomExplorer(BasicViewer):
|
185 |
+
""" DICOM viewer for basic image analysis inside iPython notebooks.
|
186 |
+
Can display a single 3D volume together with a segmentation mask, a histogram
|
187 |
+
of voxel/pixel values and some summary statistics.
|
188 |
+
Allows simple windowing by clipping the pixel/voxel values to a region, which
|
189 |
+
can be manually specified.
|
190 |
+
|
191 |
+
"""
|
192 |
+
|
193 |
+
vbox_layout = widgets.Layout(
|
194 |
+
margin = '10px 5px 5px 5px',
|
195 |
+
padding = '5px',
|
196 |
+
display='flex',
|
197 |
+
flex_flow='column',
|
198 |
+
align_items='center',
|
199 |
+
min_width = '250px')
|
200 |
+
|
201 |
+
def _plot_hist(self, px_range):
|
202 |
+
x = self.x.numpy().flatten()
|
203 |
+
fig, ax = plt.subplots(figsize=self.figsize)
|
204 |
+
N, bins, patches = plt.hist(x, 100, color='grey')
|
205 |
+
lwr = int(px_range[0] * 100/max(x))
|
206 |
+
upr = int(np.ceil(px_range[1] * 100/max(x)))
|
207 |
+
|
208 |
+
for i in range(0,lwr):
|
209 |
+
patches[i].set_facecolor('grey' if lwr > 0 else 'darkblue')
|
210 |
+
for i in range(lwr, upr):
|
211 |
+
patches[i].set_facecolor('darkblue')
|
212 |
+
for i in range(upr,100):
|
213 |
+
patches[i].set_facecolor('grey' if upr < 100 else 'darkblue')
|
214 |
+
|
215 |
+
plt.show()
|
216 |
+
|
217 |
+
def _image_summary(self, px_range):
|
218 |
+
x = self.x.clip(*px_range)
|
219 |
+
|
220 |
+
diffs = x - x.mean()
|
221 |
+
var = torch.mean(torch.pow(diffs, 2.0))
|
222 |
+
std = torch.pow(var, 0.5)
|
223 |
+
zscores = diffs / std
|
224 |
+
skews = torch.mean(torch.pow(zscores, 3.0))
|
225 |
+
kurt = torch.mean(torch.pow(zscores, 4.0)) - 3.0
|
226 |
+
|
227 |
+
table = f'Statistics:\n' + \
|
228 |
+
f' Mean px value: {x.mean()} \n' + \
|
229 |
+
f' Std of px values: {x.std()} \n' + \
|
230 |
+
f' Min px value: {x.min()} \n' + \
|
231 |
+
f' Max px value: {x.max()} \n' + \
|
232 |
+
f' Median px value: {x.median()} \n' + \
|
233 |
+
f' Skewness: {skews} \n' + \
|
234 |
+
f' Kurtosis: {kurt} \n\n' + \
|
235 |
+
f'Tensor properties \n' + \
|
236 |
+
f' Tensor shape: {tuple(x.shape)}\n' + \
|
237 |
+
f' Tensor dtype: {x.dtype}'
|
238 |
+
print(table)
|
239 |
+
|
240 |
+
def _generate_views(self):
|
241 |
+
|
242 |
+
slice_slider = _create_slider(
|
243 |
+
slider_min = min(self.slice_range),
|
244 |
+
slider_max = max(self.slice_range),
|
245 |
+
value = max(self.slice_range)//2,
|
246 |
+
readout = True)
|
247 |
+
|
248 |
+
toggle_mask_button = _create_togglebutton('Show Mask', True)
|
249 |
+
|
250 |
+
range_slider = _create_slider(
|
251 |
+
slider_min = self.x.min().numpy(),
|
252 |
+
slider_max = self.x.max().numpy(),
|
253 |
+
value = [self.x.min().numpy(), self.x.max().numpy()],
|
254 |
+
continuous_update=False,
|
255 |
+
slider_type = 'FloatRangeSlider' if torch.is_floating_point(self.x) else 'IntRandSlider',
|
256 |
+
step = 0.01 if torch.is_floating_point(self.x) else 1)
|
257 |
+
|
258 |
+
image_output = widgets.interactive_output(
|
259 |
+
f = self._plot_slice,
|
260 |
+
controls = {'im_slice': slice_slider,
|
261 |
+
'with_mask': toggle_mask_button,
|
262 |
+
'px_range': range_slider})
|
263 |
+
|
264 |
+
image_output.layout.height = f'{self.figsize[0]/1.2}in' # suppress flickering
|
265 |
+
image_output.layout.width = f'{self.figsize[1]/1.2}in' # suppress flickering
|
266 |
+
|
267 |
+
if isinstance(self.y, torch.Tensor):
|
268 |
+
slice_slider = widgets.HBox([slice_slider, toggle_mask_button])
|
269 |
+
|
270 |
+
hist_output = widgets.interactive_output(
|
271 |
+
f = self._plot_hist,
|
272 |
+
controls = {'px_range': range_slider})
|
273 |
+
|
274 |
+
hist_output.layout.height = f'{self.figsize[0]/1.2}in' # suppress flickering
|
275 |
+
hist_output.layout.width = f'{self.figsize[1]/1.2}in' # suppress flickering
|
276 |
+
|
277 |
+
toggle_mask_button = _create_togglebutton('Show Mask', True)
|
278 |
+
|
279 |
+
table_output = widgets.interactive_output(
|
280 |
+
f = self._image_summary,
|
281 |
+
controls = {'px_range': range_slider})
|
282 |
+
|
283 |
+
table_box = widgets.VBox([table_output], layout=self.vbox_layout)
|
284 |
+
|
285 |
+
hist_box = widgets.VBox(
|
286 |
+
[hist_output, range_slider],
|
287 |
+
layout=self.vbox_layout)
|
288 |
+
|
289 |
+
image_box = widgets.VBox(
|
290 |
+
[image_output, slice_slider],
|
291 |
+
layout=self.vbox_layout)
|
292 |
+
|
293 |
+
self.box = widgets.HBox(
|
294 |
+
[image_box, hist_box, table_box],
|
295 |
+
layout = widgets.Layout(
|
296 |
+
border = 'solid 1px lightgrey',
|
297 |
+
margin = '10px 5px 0px 0px',
|
298 |
+
padding = '5px',
|
299 |
+
width = f'{self.figsize[1]*2 + 3}in'))
|
300 |
+
|
301 |
+
|
302 |
+
class ListViewer(object):
|
303 |
+
""" Display multipple images with their masks or labels/predictions.
|
304 |
+
Arguments:
|
305 |
+
x (tuple, list): Tensor objects to view
|
306 |
+
y (tuple, list): Tensor objects (in case of segmentation task) or class labels as string.
|
307 |
+
predictions (str): Class predictions
|
308 |
+
cmap: colormap for display of `x`
|
309 |
+
max_n: maximum number of items to display
|
310 |
+
"""
|
311 |
+
|
312 |
+
def __init__(self, x:(list, tuple), y=None, prediction:str=None, description: str=None,
|
313 |
+
figsize=(4, 4), cmap:str='bone', max_n = 9):
|
314 |
+
self.slice_range = (1, len(x))
|
315 |
+
x = x[0:max_n]
|
316 |
+
if y: y = y[0:max_n]
|
317 |
+
self.x=x
|
318 |
+
self.y=y
|
319 |
+
self.prediction=prediction
|
320 |
+
self.description=description
|
321 |
+
self.figsize=figsize
|
322 |
+
self.cmap=cmap
|
323 |
+
self.max_n=max_n
|
324 |
+
|
325 |
+
def _generate_views(self):
|
326 |
+
n_images = len(self.x)
|
327 |
+
image_grid, image_list = [], []
|
328 |
+
|
329 |
+
for i in range(0, n_images):
|
330 |
+
image = self.x[i]
|
331 |
+
mask = self.y[i] if isinstance(self.y, list) else None
|
332 |
+
pred = self.prediction[i] if self.prediction else None
|
333 |
+
|
334 |
+
image_list.append(
|
335 |
+
BasicViewer(
|
336 |
+
x = image,
|
337 |
+
y = mask,
|
338 |
+
prediction = pred,
|
339 |
+
figsize = self.figsize,
|
340 |
+
cmap = self.cmap)
|
341 |
+
.image_box)
|
342 |
+
|
343 |
+
if (i+1) % np.ceil(np.sqrt(n_images)) == 0 or i == n_images - 1:
|
344 |
+
image_grid.append(widgets.HBox(image_list))
|
345 |
+
image_list = []
|
346 |
+
|
347 |
+
self.box = widgets.VBox(children=image_grid)
|
348 |
+
|
349 |
+
def show(self):
|
350 |
+
self._generate_views()
|
351 |
+
plt.style.use('default')
|
352 |
+
display(self.box)
|