Spaces:
Runtime error
Runtime error
Delete prostate
Browse files- prostate/data.py +0 -229
- prostate/loss.py +0 -23
- prostate/model.py +0 -16
- prostate/optimizer.py +0 -37
- prostate/report.py +0 -184
- prostate/train.py +0 -482
- prostate/transforms.py +0 -363
- prostate/utils.py +0 -56
- prostate/viewer.py +0 -352
prostate/data.py
DELETED
@@ -1,229 +0,0 @@
|
|
1 |
-
# create dataloaders form csv file
|
2 |
-
|
3 |
-
## ---------- imports ----------
|
4 |
-
import os
|
5 |
-
import torch
|
6 |
-
import shutil
|
7 |
-
import numpy as np
|
8 |
-
import pandas as pd
|
9 |
-
from typing import Union
|
10 |
-
from monai.utils import first
|
11 |
-
from functools import partial
|
12 |
-
from collections import namedtuple
|
13 |
-
from monai.data import DataLoader as MonaiDataLoader
|
14 |
-
|
15 |
-
from . import transforms
|
16 |
-
from .utils import num_workers
|
17 |
-
|
18 |
-
|
19 |
-
def import_dataset(config: dict):
|
20 |
-
if config.data.dataset_type == 'persistent':
|
21 |
-
from monai.data import PersistentDataset
|
22 |
-
if os.path.exists(config.data.cache_dir):
|
23 |
-
shutil.rmtree(config.data.cache_dir) # rm previous cache DS
|
24 |
-
os.makedirs(config.data.cache_dir, exist_ok = True)
|
25 |
-
Dataset = partial(PersistentDataset, cache_dir = config.data.cache_dir)
|
26 |
-
elif config.data.dataset_type == 'cache':
|
27 |
-
from monai.data import CacheDataset
|
28 |
-
raise NotImplementedError('CacheDataset not yet implemented')
|
29 |
-
else:
|
30 |
-
from monai.data import Dataset
|
31 |
-
return Dataset
|
32 |
-
|
33 |
-
|
34 |
-
class DataLoader(MonaiDataLoader):
|
35 |
-
"overwrite monai DataLoader for enhanced viewing capabilities"
|
36 |
-
|
37 |
-
def show_batch(self,
|
38 |
-
image_key: str='image',
|
39 |
-
label_key: str='label',
|
40 |
-
image_transform=lambda x: x.squeeze().transpose(0,2).flip(-2),
|
41 |
-
label_transform=lambda x: x.squeeze().transpose(0,2).flip(-2)):
|
42 |
-
"""Args:
|
43 |
-
image_key: dict key name for image to view
|
44 |
-
label_key: dict kex name for corresponding label. Can be a tensor or str
|
45 |
-
image_transform: transform input before it is passed to the viewer to ensure
|
46 |
-
ndim of the image is equal to 3 and image is oriented correctly
|
47 |
-
label_transform: transform labels before passed to the viewer, to ensure
|
48 |
-
segmentations masks have same shape and orientations as images. Should be
|
49 |
-
identity function of labels are str.
|
50 |
-
"""
|
51 |
-
from .viewer import ListViewer
|
52 |
-
|
53 |
-
batch = first(self)
|
54 |
-
image = torch.unbind(batch[image_key], 0)
|
55 |
-
label = torch.unbind(batch[label_key], 0)
|
56 |
-
|
57 |
-
ListViewer([image_transform(im) for im in image],
|
58 |
-
[label_transform(im) for im in label]).show()
|
59 |
-
|
60 |
-
# TODO
|
61 |
-
## Work with 3 dataloaders
|
62 |
-
|
63 |
-
def segmentation_dataloaders(config: dict,
|
64 |
-
train: bool = None,
|
65 |
-
valid: bool = None,
|
66 |
-
test: bool = None,
|
67 |
-
):
|
68 |
-
"""Create segmentation dataloaders
|
69 |
-
Args:
|
70 |
-
config: config file
|
71 |
-
train: whether to return a train DataLoader
|
72 |
-
valid: whether to return a valid DataLoader
|
73 |
-
test: whether to return a test DateLoader
|
74 |
-
Args from config:
|
75 |
-
data_dir: base directory for the data
|
76 |
-
csv_name: path to csv file containing filenames and paths
|
77 |
-
image_cols: columns in csv containing path to images
|
78 |
-
label_cols: columns in csv containing path to label files
|
79 |
-
dataset_type: PersistentDataset, CacheDataset and Dataset are supported
|
80 |
-
cache_dir: cache directory to be used by PersistentDataset
|
81 |
-
batch_size: batch size for training. Valid and test are always 1
|
82 |
-
debug: run with reduced number of images
|
83 |
-
Returns:
|
84 |
-
list of:
|
85 |
-
train_loader: DataLoader (optional, if train==True)
|
86 |
-
valid_loader: DataLoader (optional, if valid==True)
|
87 |
-
test_loader: DataLoader (optional, if test==True)
|
88 |
-
"""
|
89 |
-
|
90 |
-
## parse needed rguments from config
|
91 |
-
if train is None: train = config.data.train
|
92 |
-
if valid is None: valid = config.data.valid
|
93 |
-
if test is None: test = config.data.test
|
94 |
-
|
95 |
-
data_dir = config.data.data_dir
|
96 |
-
train_csv = config.data.train_csv
|
97 |
-
valid_csv = config.data.valid_csv
|
98 |
-
test_csv = config.data.test_csv
|
99 |
-
image_cols = config.data.image_cols
|
100 |
-
label_cols = config.data.label_cols
|
101 |
-
dataset_type = config.data.dataset_type
|
102 |
-
cache_dir = config.data.cache_dir
|
103 |
-
batch_size = config.data.batch_size
|
104 |
-
debug = config.debug
|
105 |
-
|
106 |
-
## ---------- data dicts ----------
|
107 |
-
|
108 |
-
# first a global data dict, containing only the filepath from image_cols and label_cols is created. For this,
|
109 |
-
# the dataframe is reduced to only the relevant columns. Then the rows are iterated, converting each row into an
|
110 |
-
# individual dict, as expected by monai
|
111 |
-
|
112 |
-
if not isinstance(image_cols, (tuple, list)): image_cols = [image_cols]
|
113 |
-
if not isinstance(label_cols, (tuple, list)): label_cols = [label_cols]
|
114 |
-
|
115 |
-
train_df = pd.read_csv(train_csv)
|
116 |
-
valid_df = pd.read_csv(valid_csv)
|
117 |
-
test_df = pd.read_csv(test_csv)
|
118 |
-
if debug:
|
119 |
-
train_df = train_df.sample(25)
|
120 |
-
valid_df = valid_df.sample(5)
|
121 |
-
|
122 |
-
train_df['split']='train'
|
123 |
-
valid_df['split']='valid'
|
124 |
-
test_df['split']='test'
|
125 |
-
whole_df = []
|
126 |
-
if train: whole_df += [train_df]
|
127 |
-
if valid: whole_df += [valid_df]
|
128 |
-
if test: whole_df += [test_df]
|
129 |
-
df = pd.concat(whole_df)
|
130 |
-
cols = image_cols + label_cols
|
131 |
-
for col in cols:
|
132 |
-
# create absolute file name from relative fn in df and data_dir
|
133 |
-
df[col] = [os.path.join(data_dir, fn) for fn in df[col]]
|
134 |
-
if not os.path.exists(list(df[col])[0]):
|
135 |
-
raise FileNotFoundError(list(df[col])[0])
|
136 |
-
|
137 |
-
|
138 |
-
data_dict = [dict(row[1]) for row in df[cols].iterrows()]
|
139 |
-
# data_dict is not the correct name, list_of_data_dicts would be more accurate, but also longer.
|
140 |
-
# The data_dict looks like this:
|
141 |
-
# [
|
142 |
-
# {'image_col_1': 'data_dir/path/to/image1',
|
143 |
-
# 'image_col_2': 'data_dir/path/to/image2'
|
144 |
-
# 'label_col_1': 'data_dir/path/to/label1},
|
145 |
-
# {'image_col_1': 'data_dir/path/to/image1',
|
146 |
-
# 'image_col_2': 'data_dir/path/to/image2'
|
147 |
-
# 'label_col_1': 'data_dir/path/to/label1},
|
148 |
-
# ...]
|
149 |
-
# Filename should now be absolute or relative to working directory
|
150 |
-
|
151 |
-
# now we create separate data dicts for train, valid and test data respectively
|
152 |
-
assert train or test or valid, 'No dataset type is specified (train/valid or test)'
|
153 |
-
|
154 |
-
if test:
|
155 |
-
test_files = list(map(data_dict.__getitem__, *np.where(df.split == 'test')))
|
156 |
-
|
157 |
-
if valid:
|
158 |
-
val_files = list(map(data_dict.__getitem__, *np.where(df.split == 'valid')))
|
159 |
-
|
160 |
-
if train:
|
161 |
-
train_files = list(map(data_dict.__getitem__, *np.where(df.split == 'train')))
|
162 |
-
|
163 |
-
# transforms are specified in transforms.py and are just loaded here
|
164 |
-
if train: train_transforms = transforms.get_train_transforms(config)
|
165 |
-
if valid: val_transforms = transforms.get_val_transforms(config)
|
166 |
-
if test: test_transforms = transforms.get_test_transforms(config)
|
167 |
-
|
168 |
-
|
169 |
-
## ---------- construct dataloaders ----------
|
170 |
-
Dataset=import_dataset(config)
|
171 |
-
data_loaders = []
|
172 |
-
if train:
|
173 |
-
train_ds = Dataset(
|
174 |
-
data=train_files,
|
175 |
-
transform=train_transforms
|
176 |
-
)
|
177 |
-
train_loader = DataLoader(
|
178 |
-
train_ds,
|
179 |
-
batch_size=batch_size,
|
180 |
-
num_workers=num_workers(),
|
181 |
-
shuffle=True
|
182 |
-
)
|
183 |
-
data_loaders.append(train_loader)
|
184 |
-
|
185 |
-
if valid:
|
186 |
-
val_ds = Dataset(
|
187 |
-
data=val_files,
|
188 |
-
transform=val_transforms
|
189 |
-
)
|
190 |
-
val_loader = DataLoader(
|
191 |
-
val_ds,
|
192 |
-
batch_size=1,
|
193 |
-
num_workers=num_workers(),
|
194 |
-
shuffle=False
|
195 |
-
)
|
196 |
-
data_loaders.append(val_loader)
|
197 |
-
|
198 |
-
if test:
|
199 |
-
test_ds = Dataset(
|
200 |
-
data=test_files,
|
201 |
-
transform=test_transforms
|
202 |
-
)
|
203 |
-
test_loader = DataLoader(
|
204 |
-
test_ds,
|
205 |
-
batch_size=1,
|
206 |
-
num_workers=num_workers(),
|
207 |
-
shuffle=False
|
208 |
-
)
|
209 |
-
data_loaders.append(test_loader)
|
210 |
-
|
211 |
-
# if only one dataloader is constructed, return only this dataloader else return a named tuple with dataloaders,
|
212 |
-
# so it is clear which DataLoader is train/valid or test
|
213 |
-
|
214 |
-
if len(data_loaders) == 1:
|
215 |
-
return data_loaders[0]
|
216 |
-
else:
|
217 |
-
DataLoaders = namedtuple(
|
218 |
-
'DataLoaders',
|
219 |
-
# create str with specification of loader type if train and test are true but
|
220 |
-
# valid is false string will be 'train test'
|
221 |
-
' '.join(
|
222 |
-
[
|
223 |
-
'train' if train else '',
|
224 |
-
'valid' if valid else '',
|
225 |
-
'test' if test else ''
|
226 |
-
]
|
227 |
-
).strip()
|
228 |
-
)
|
229 |
-
return DataLoaders(*data_loaders)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
prostate/loss.py
DELETED
@@ -1,23 +0,0 @@
|
|
1 |
-
import monai
|
2 |
-
from .utils import load_config
|
3 |
-
|
4 |
-
|
5 |
-
def get_loss(config: dict):
|
6 |
-
"""Create a loss function of `type` with specific keyword arguments from config.
|
7 |
-
Example:
|
8 |
-
|
9 |
-
config.loss
|
10 |
-
>>> {'DiceCELoss': {'include_background': False, 'softmax': True, 'to_onehot_y': True}}
|
11 |
-
|
12 |
-
get_loss(config)
|
13 |
-
>>> DiceCELoss(
|
14 |
-
>>> (dice): DiceLoss()
|
15 |
-
>>> (cross_entropy): CrossEntropyLoss()
|
16 |
-
>>> )
|
17 |
-
|
18 |
-
"""
|
19 |
-
loss_type = list(config.loss.keys())[0]
|
20 |
-
loss_config = config.loss[loss_type]
|
21 |
-
loss_fun = getattr(monai.losses, loss_type)
|
22 |
-
loss = loss_fun(**loss_config)
|
23 |
-
return loss
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
prostate/model.py
DELETED
@@ -1,16 +0,0 @@
|
|
1 |
-
# create a standard UNet
|
2 |
-
|
3 |
-
from monai.networks.nets import UNet
|
4 |
-
|
5 |
-
def get_model(config: dict):
|
6 |
-
return UNet(
|
7 |
-
spatial_dims=config.ndim,
|
8 |
-
in_channels=len(config.data.image_cols),
|
9 |
-
out_channels=config.model.out_channels,
|
10 |
-
channels=config.model.channels,
|
11 |
-
strides=config.model.strides,
|
12 |
-
num_res_units=config.model.num_res_units,
|
13 |
-
act=config.model.act,
|
14 |
-
norm=config.model.norm,
|
15 |
-
dropout=config.model.dropout,
|
16 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
prostate/optimizer.py
DELETED
@@ -1,37 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
import monai
|
3 |
-
from .utils import load_config
|
4 |
-
|
5 |
-
|
6 |
-
def get_optimizer(model: torch.nn.Module,
|
7 |
-
config: dict):
|
8 |
-
"""Create an optimizer of `type` with specific keyword arguments from config.
|
9 |
-
Example:
|
10 |
-
|
11 |
-
config.optimizer
|
12 |
-
>>> {'Novograd': {'lr': 0.001, 'weight_decay': 0.01}}
|
13 |
-
|
14 |
-
get_optimizer(model, config)
|
15 |
-
>>> Novograd (
|
16 |
-
>>> Parameter Group 0
|
17 |
-
>>> amsgrad: False
|
18 |
-
>>> betas: (0.9, 0.999)
|
19 |
-
>>> eps: 1e-08
|
20 |
-
>>> grad_averaging: False
|
21 |
-
>>> lr: 0.0001
|
22 |
-
>>> weight_decay: 0.001
|
23 |
-
>>> )
|
24 |
-
|
25 |
-
"""
|
26 |
-
optimizer_type = list(config.optimizer.keys())[0]
|
27 |
-
opt_config = config.optimizer[optimizer_type]
|
28 |
-
if hasattr(torch.optim, optimizer_type):
|
29 |
-
optimizer_fun = getattr(torch.optim, optimizer_type)
|
30 |
-
elif hasattr(monai.optimizers, optimizer_type):
|
31 |
-
optimizer_fun = getattr(monai.optimizers, optimizer_type)
|
32 |
-
else:
|
33 |
-
raise ValueError(f'Optimizer {optimizer_type} not found')
|
34 |
-
optimizer = optimizer_fun(model.parameters(), **opt_config)
|
35 |
-
return optimizer
|
36 |
-
|
37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
prostate/report.py
DELETED
@@ -1,184 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import io
|
3 |
-
import cv2
|
4 |
-
import tqdm
|
5 |
-
import torch
|
6 |
-
import imageio
|
7 |
-
import numpy as np
|
8 |
-
import pandas as pd
|
9 |
-
import matplotlib.pyplot as plt
|
10 |
-
|
11 |
-
class ReportGenerator():
|
12 |
-
"Generate markdown document, summarizing the training"
|
13 |
-
|
14 |
-
def __init__(self, run_id, out_dir=None, log_dir=None):
|
15 |
-
|
16 |
-
self.run_id, self.out_dir, self.log_dir = run_id, out_dir, log_dir
|
17 |
-
|
18 |
-
if log_dir:
|
19 |
-
self.train_logs = pd.read_csv(os.path.join(log_dir, 'train_logs.csv'))
|
20 |
-
self.metric_logs = pd.read_csv(os.path.join(log_dir, 'metric_logs.csv'))
|
21 |
-
if out_dir:
|
22 |
-
self.dice = pd.read_csv(os.path.join(out_dir, 'MeanDice_raw.csv'))
|
23 |
-
self.hausdorf = pd.read_csv(os.path.join(out_dir, 'HausdorffDistance_raw.csv'))
|
24 |
-
self.surface = pd.read_csv(os.path.join(out_dir, 'SurfaceDistance_raw.csv'))
|
25 |
-
|
26 |
-
self.mean_metrics = pd.DataFrame(
|
27 |
-
{"mean_dice" : [round(np.mean(self.dice[col]),3) for col in self.dice if col.startswith('class')],
|
28 |
-
"mean_hausdorf" : [round(np.mean(self.hausdorf[col]),3) for col in self.hausdorf if col.startswith('class')],
|
29 |
-
"mean_surface" : [round(np.mean(self.surface[col]),3) for col in self.surface if col.startswith('class')]
|
30 |
-
}).transpose()
|
31 |
-
|
32 |
-
def generate_report(self, loss_plot=True, metric_plot=True, boxplots=True, animation=True):
|
33 |
-
fn = os.path.join(self.run_id, 'report', 'SegmentationReport.md')
|
34 |
-
os.makedirs(os.path.join(self.run_id, 'report'), exist_ok=True)
|
35 |
-
with open(fn, 'w+') as f:
|
36 |
-
f.write('# Segmentation Report\n\n')
|
37 |
-
|
38 |
-
if loss_plot:
|
39 |
-
fig = self.plot_loss(self.train_logs, self.metric_logs)
|
40 |
-
plt.savefig(os.path.join(self.run_id, 'report', 'loss_and_lr.png'), dpi = 150)
|
41 |
-
|
42 |
-
with open(fn, 'a') as f:
|
43 |
-
f.write('## Loss, LR-Schedule and Key Metric\n')
|
44 |
-
f.write('\n\n')
|
45 |
-
|
46 |
-
if metric_plot:
|
47 |
-
fig = plt.figure("metrics", (18, 6))
|
48 |
-
|
49 |
-
ax = plt.subplot(1, 3, 1)
|
50 |
-
plt.ylim([0,1])
|
51 |
-
plt.title("Mean Dice")
|
52 |
-
plt.xlabel("epoch")
|
53 |
-
plt.plot(self.metric_logs.index, self.metric_logs.MeanDice)
|
54 |
-
|
55 |
-
ax = plt.subplot(1, 3, 2)
|
56 |
-
plt.title("Mean Hausdorff Distance")
|
57 |
-
plt.xlabel("epoch")
|
58 |
-
plt.plot(self.metric_logs.index, self.metric_logs.HausdorffDistance)
|
59 |
-
|
60 |
-
ax = plt.subplot(1, 3, 3)
|
61 |
-
plt.title("Mean Surface Distance")
|
62 |
-
plt.xlabel("epoch")
|
63 |
-
plt.plot(self.metric_logs.index, self.metric_logs.SurfaceDistance)
|
64 |
-
|
65 |
-
plt.savefig(os.path.join(self.run_id, 'report', 'metrics.png'), dpi = 150)
|
66 |
-
fig.clear()
|
67 |
-
plt.close()
|
68 |
-
|
69 |
-
with open(fn, 'a') as f:
|
70 |
-
f.write('## Metrics\n')
|
71 |
-
f.write('\n\n')
|
72 |
-
|
73 |
-
if boxplots:
|
74 |
-
fig = plt.figure("boxplots", (18, 6))
|
75 |
-
|
76 |
-
ax = plt.subplot(1, 3, 1)
|
77 |
-
plt.title("Dice")
|
78 |
-
plt.xlabel("class")
|
79 |
-
plt.boxplot(self.dice[[col for col in self.dice if col.startswith('class')]])
|
80 |
-
|
81 |
-
ax = plt.subplot(1, 3, 2)
|
82 |
-
plt.title("Hausdorff Distance")
|
83 |
-
plt.xlabel("class")
|
84 |
-
plt.boxplot(self.hausdorf[[col for col in self.hausdorf if col.startswith('class')]])
|
85 |
-
|
86 |
-
ax = plt.subplot(1, 3, 3)
|
87 |
-
plt.title("Surface Distance")
|
88 |
-
plt.xlabel("class")
|
89 |
-
plt.boxplot(self.surface[[col for col in self.surface if col.startswith('class')]])
|
90 |
-
|
91 |
-
plt.savefig(os.path.join(self.run_id, 'report', 'boxplots.png'),dpi = 150)
|
92 |
-
|
93 |
-
fig.clear()
|
94 |
-
plt.close()
|
95 |
-
|
96 |
-
with open(fn, 'a') as f:
|
97 |
-
f.write(f"## Individual metrics\n\n")
|
98 |
-
f.write(f"{self.mean_metrics.to_markdown()}\n\n")
|
99 |
-
f.write(f"\n\n")
|
100 |
-
if animation:
|
101 |
-
self.generate_gif()
|
102 |
-
with open(fn, 'a') as f:
|
103 |
-
f.write('## Visualization of progress\n')
|
104 |
-
f.write('\n\n')
|
105 |
-
|
106 |
-
def plot_loss(self, train_logs, metric_logs):
|
107 |
-
iteration = train_logs.iteration/sum(train_logs.epoch == 1)
|
108 |
-
fig = plt.figure("loss and lr", (12, 6))
|
109 |
-
|
110 |
-
y_max = max(metric_logs.eval_loss) + 0.5
|
111 |
-
if y_max > 3: y_max = 3
|
112 |
-
|
113 |
-
ax = plt.subplot(1, 2, 1)
|
114 |
-
plt.ylim([0,y_max])
|
115 |
-
plt.title("Epoch Average Loss")
|
116 |
-
plt.xlabel("epoch")
|
117 |
-
plt.plot(iteration, train_logs.loss)
|
118 |
-
plt.plot(metric_logs.index, metric_logs.eval_loss)
|
119 |
-
|
120 |
-
ax = plt.subplot(1, 2, 2)
|
121 |
-
ax.set_yscale('log')
|
122 |
-
plt.title("LR Schedule")
|
123 |
-
plt.xlabel("epoch")
|
124 |
-
plt.plot(iteration, train_logs.lr)
|
125 |
-
return fig
|
126 |
-
|
127 |
-
def get_arr_from_fig(self, fig, dpi=180):
|
128 |
-
buf = io.BytesIO()
|
129 |
-
fig.savefig(buf, format="png", dpi=dpi)
|
130 |
-
buf.seek(0)
|
131 |
-
img_arr = np.frombuffer(buf.getvalue(), dtype=np.uint8)
|
132 |
-
buf.close()
|
133 |
-
img = cv2.imdecode(img_arr, 1)
|
134 |
-
#img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
135 |
-
return img
|
136 |
-
|
137 |
-
def get_slices(self, im, slices):
|
138 |
-
ims = torch.unbind(im[:, :, slices], -1) # extract n slices
|
139 |
-
ims = [i.transpose(0,1).flip(0) for i in ims] # rotate slices 90 degrees
|
140 |
-
if len(slices) > 4 and len(slices) % 2 == 0:
|
141 |
-
n = len(slices) // 2
|
142 |
-
ims1 = torch.cat(ims[0:n], 1)
|
143 |
-
ims2 = torch.cat(ims[n:], 1)
|
144 |
-
return torch.cat([ims1, ims2], 0)
|
145 |
-
else:
|
146 |
-
return torch.cat(ims, 1) # create tile
|
147 |
-
|
148 |
-
def plot_images(self, fns, slices, cmap='Greys_r', figsize=15, **kwargs):
|
149 |
-
ims = [torch.load(os.path.join(self.out_dir, 'preds', fn)).cpu().argmax(0) for fn in fns]
|
150 |
-
ims = [self.get_slices(im, slices) for im in ims]
|
151 |
-
ims = torch.cat(ims, 0)
|
152 |
-
plt.figure(figsize=(figsize,figsize))
|
153 |
-
plt.imshow(ims, cmap=cmap, **kwargs)
|
154 |
-
plt.axis('off')
|
155 |
-
|
156 |
-
def load_segmentation_image(self, fn):
|
157 |
-
im = torch.load(fn).cpu().unsqueeze(0)
|
158 |
-
im = torch.nn.functional.interpolate(im, (224, 224, 112))
|
159 |
-
im = im.argmax(1).squeeze()
|
160 |
-
im = self.get_slices(im, slices = (40, 48, 56, 74, 82, 90))
|
161 |
-
im = im/im.max() * 255
|
162 |
-
return im
|
163 |
-
|
164 |
-
def generate_gif(self):
|
165 |
-
with imageio.get_writer(
|
166 |
-
os.path.join(self.run_id,'report','progress.gif'),
|
167 |
-
mode='I',
|
168 |
-
fps = max(self.train_logs.epoch) // 10) as writer: # make gif 10 seconds
|
169 |
-
for epoch in tqdm.tqdm(list(self.train_logs.epoch.unique())):
|
170 |
-
seg_fn = os.path.join(self.out_dir, 'preds', f"pred_epoch_{epoch}.pt")
|
171 |
-
if os.path.exists(seg_fn): im = self.load_segmentation_image(seg_fn)
|
172 |
-
|
173 |
-
plt_train_logs = self.train_logs[self.train_logs.epoch <= epoch]
|
174 |
-
loss_plt = self.plot_loss(plt_train_logs, self.metric_logs[:epoch])
|
175 |
-
loss_fig = self.get_arr_from_fig(loss_plt)[:,:,0]
|
176 |
-
|
177 |
-
new_shape = im.shape[1], int(loss_fig.shape[0] / loss_fig.shape[1] * im.shape[1])
|
178 |
-
loss_fig = cv2.resize(loss_fig, (im.shape[1], im.shape[0]))
|
179 |
-
|
180 |
-
images = torch.cat([im, torch.tensor(loss_fig)], 0).numpy().astype(np.uint8)
|
181 |
-
writer.append_data(images)
|
182 |
-
|
183 |
-
loss_plt.clear()
|
184 |
-
plt.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
prostate/train.py
DELETED
@@ -1,482 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import yaml
|
3 |
-
import munch
|
4 |
-
import torch
|
5 |
-
import ignite
|
6 |
-
import monai
|
7 |
-
import shutil
|
8 |
-
import pandas as pd
|
9 |
-
|
10 |
-
from typing import Union, List, Callable
|
11 |
-
from ignite.contrib.handlers.tqdm_logger import ProgressBar
|
12 |
-
from monai.handlers import (
|
13 |
-
CheckpointSaver,
|
14 |
-
StatsHandler,
|
15 |
-
TensorBoardStatsHandler,
|
16 |
-
TensorBoardImageHandler,
|
17 |
-
ValidationHandler,
|
18 |
-
from_engine,
|
19 |
-
MeanDice,
|
20 |
-
EarlyStopHandler,
|
21 |
-
MetricLogger,
|
22 |
-
MetricsSaver
|
23 |
-
)
|
24 |
-
|
25 |
-
from .data import segmentation_dataloaders
|
26 |
-
from .model import get_model
|
27 |
-
from .optimizer import get_optimizer
|
28 |
-
from .loss import get_loss
|
29 |
-
from .transforms import get_val_post_transforms
|
30 |
-
from .utils import USE_AMP
|
31 |
-
|
32 |
-
def loss_logger(engine):
|
33 |
-
"write loss and lr of each iteration/epoch to file"
|
34 |
-
iteration=engine.state.iteration
|
35 |
-
epoch=engine.state.epoch
|
36 |
-
loss=[o['loss'] for o in engine.state.output]
|
37 |
-
loss=sum(loss)/len(loss)
|
38 |
-
lr=engine.optimizer.param_groups[0]['lr']
|
39 |
-
log_file=os.path.join(engine.config.log_dir, 'train_logs.csv')
|
40 |
-
if not os.path.exists(log_file):
|
41 |
-
with open(log_file, 'w+') as f:
|
42 |
-
f.write('iteration,epoch,loss,lr\n')
|
43 |
-
with open(log_file, 'a') as f:
|
44 |
-
f.write(f'{iteration},{epoch},{loss},{lr}\n')
|
45 |
-
|
46 |
-
def metric_logger(engine):
|
47 |
-
"write `metrics` after each epoch to file"
|
48 |
-
if engine.state.epoch > 1: # only key metric is calcualted in 1st epoch, needs fix
|
49 |
-
metric_names=[k for k in engine.state.metrics.keys()]
|
50 |
-
metrics=[str(engine.state.metrics[mn]) for mn in metric_names]
|
51 |
-
log_file=os.path.join(engine.config.log_dir, 'metric_logs.csv')
|
52 |
-
if not os.path.exists(log_file):
|
53 |
-
with open(log_file, 'w+') as f:
|
54 |
-
f.write(','.join(metric_names) + '\n')
|
55 |
-
with open(log_file, 'a') as f:
|
56 |
-
f.write(','.join(metrics) + '\n')
|
57 |
-
|
58 |
-
def pred_logger(engine):
|
59 |
-
"save `pred` each time metric improves"
|
60 |
-
epoch=engine.state.epoch
|
61 |
-
root = os.path.join(engine.config.out_dir, 'preds')
|
62 |
-
if not os.path.exists(root):
|
63 |
-
os.makedirs(root)
|
64 |
-
torch.save(
|
65 |
-
engine.state.output[0]['label'],
|
66 |
-
os.path.join(root, f'label.pt')
|
67 |
-
)
|
68 |
-
torch.save(
|
69 |
-
engine.state.output[0]['image'],
|
70 |
-
os.path.join(root, f'image.pt')
|
71 |
-
)
|
72 |
-
|
73 |
-
if epoch==engine.state.best_metric_epoch:
|
74 |
-
torch.save(
|
75 |
-
engine.state.output[0]['pred'],
|
76 |
-
os.path.join(root, f'pred_epoch_{epoch}.pt')
|
77 |
-
)
|
78 |
-
|
79 |
-
|
80 |
-
def get_val_handlers(
|
81 |
-
network: torch.nn.Module,
|
82 |
-
config: dict
|
83 |
-
) -> list:
|
84 |
-
"""Create default handlers for model validation
|
85 |
-
Args:
|
86 |
-
network:
|
87 |
-
nn.Module subclass, the model to train
|
88 |
-
|
89 |
-
Returns:
|
90 |
-
a list of default handlers for validation: [
|
91 |
-
StatsHandler:
|
92 |
-
???
|
93 |
-
TensorBoardStatsHandler:
|
94 |
-
Save loss from validation to `config.log_dir`, allow logging with TensorBoard
|
95 |
-
CheckpointSaver:
|
96 |
-
Save best model to `config.model_dir`
|
97 |
-
]
|
98 |
-
"""
|
99 |
-
|
100 |
-
val_handlers=[
|
101 |
-
StatsHandler(
|
102 |
-
tag_name="metric_logger",
|
103 |
-
epoch_print_logger=metric_logger,
|
104 |
-
output_transform=lambda x: None
|
105 |
-
),
|
106 |
-
StatsHandler(
|
107 |
-
tag_name="pred_logger",
|
108 |
-
epoch_print_logger=pred_logger,
|
109 |
-
output_transform=lambda x: None
|
110 |
-
),
|
111 |
-
TensorBoardStatsHandler(
|
112 |
-
log_dir=config.log_dir,
|
113 |
-
# tag_name="val_mean_dice",
|
114 |
-
output_transform=lambda x: None
|
115 |
-
),
|
116 |
-
TensorBoardImageHandler(
|
117 |
-
log_dir=config.log_dir,
|
118 |
-
batch_transform=from_engine(["image", "label"]),
|
119 |
-
output_transform=from_engine(["pred"]),
|
120 |
-
),
|
121 |
-
CheckpointSaver(
|
122 |
-
save_dir=config.model_dir,
|
123 |
-
save_dict={f"network_{config.run_id}": network},
|
124 |
-
save_key_metric=True
|
125 |
-
),
|
126 |
-
|
127 |
-
]
|
128 |
-
|
129 |
-
return val_handlers
|
130 |
-
|
131 |
-
|
132 |
-
def get_train_handlers(
|
133 |
-
evaluator: monai.engines.SupervisedEvaluator,
|
134 |
-
config: dict
|
135 |
-
) -> list:
|
136 |
-
"""Create default handlers for model training
|
137 |
-
Args:
|
138 |
-
evaluator: an engine of type `monai.engines.SupervisedEvaluator` for evaluations
|
139 |
-
every epoch
|
140 |
-
|
141 |
-
Returns:
|
142 |
-
list of default handlers for training: [
|
143 |
-
ValidationHandler:
|
144 |
-
Allows model validation every epoch
|
145 |
-
StatsHandler:
|
146 |
-
???
|
147 |
-
TensorBoardStatsHandler:
|
148 |
-
Save loss from validation to `config.log_dir`, allow logging with TensorBoard
|
149 |
-
]
|
150 |
-
"""
|
151 |
-
|
152 |
-
train_handlers=[
|
153 |
-
ValidationHandler(
|
154 |
-
validator=evaluator,
|
155 |
-
interval=1,
|
156 |
-
epoch_level=True
|
157 |
-
),
|
158 |
-
StatsHandler(
|
159 |
-
tag_name="train_loss",
|
160 |
-
output_transform=from_engine(
|
161 |
-
["loss"],
|
162 |
-
first=True
|
163 |
-
)
|
164 |
-
),
|
165 |
-
StatsHandler(
|
166 |
-
tag_name='loss_logger',
|
167 |
-
iteration_print_logger=loss_logger
|
168 |
-
),
|
169 |
-
TensorBoardStatsHandler(
|
170 |
-
log_dir=config.log_dir,
|
171 |
-
tag_name="train_loss",
|
172 |
-
output_transform=from_engine(
|
173 |
-
["loss"],
|
174 |
-
first=True
|
175 |
-
),
|
176 |
-
)
|
177 |
-
]
|
178 |
-
|
179 |
-
return train_handlers
|
180 |
-
|
181 |
-
def get_evaluator(
|
182 |
-
config: dict,
|
183 |
-
device: torch.device ,
|
184 |
-
network: torch.nn.Module,
|
185 |
-
val_data_loader: monai.data.dataloader.DataLoader,
|
186 |
-
val_post_transforms: monai.transforms.compose.Compose,
|
187 |
-
val_handlers: Union[Callable, List]=get_val_handlers
|
188 |
-
) -> monai.engines.SupervisedEvaluator:
|
189 |
-
|
190 |
-
"""Create default evaluator for training of a segmentation model
|
191 |
-
Args:
|
192 |
-
device:
|
193 |
-
torch.cuda.device for model and engine
|
194 |
-
network:
|
195 |
-
nn.Module subclass, the model to train
|
196 |
-
val_data_loader:
|
197 |
-
Validation data loader, `monai.data.dataloader.DataLoader` subclass
|
198 |
-
val_post_transforms:
|
199 |
-
function to create transforms OR composed transforms
|
200 |
-
val_handlers:
|
201 |
-
function to create handerls OR List of handlers
|
202 |
-
|
203 |
-
Returns:
|
204 |
-
default evaluator for segmentation of type `monai.engines.SupervisedEvaluator`
|
205 |
-
"""
|
206 |
-
|
207 |
-
if callable(val_handlers): val_handlers=val_handlers()
|
208 |
-
|
209 |
-
evaluator=monai.engines.SupervisedEvaluator(
|
210 |
-
device=device,
|
211 |
-
val_data_loader=val_data_loader,
|
212 |
-
network=network,
|
213 |
-
inferer=monai.inferers.SlidingWindowInferer(
|
214 |
-
roi_size=(96, 96, 96),
|
215 |
-
sw_batch_size=4,
|
216 |
-
overlap=0.5
|
217 |
-
),
|
218 |
-
postprocessing=val_post_transforms,
|
219 |
-
key_val_metric={
|
220 |
-
"val_mean_dice": MeanDice(
|
221 |
-
include_background=False,
|
222 |
-
output_transform=from_engine(
|
223 |
-
["pred", "label"]
|
224 |
-
)
|
225 |
-
)
|
226 |
-
},
|
227 |
-
val_handlers=val_handlers,
|
228 |
-
# if no FP16 support in GPU or PyTorch version < 1.6, will not enable AMP evaluation
|
229 |
-
amp=USE_AMP,
|
230 |
-
)
|
231 |
-
evaluator.config=config
|
232 |
-
return evaluator
|
233 |
-
|
234 |
-
|
235 |
-
class SegmentationTrainer(monai.engines.SupervisedTrainer):
|
236 |
-
"Default Trainer für supervised segmentation task"
|
237 |
-
def __init__(self,
|
238 |
-
config: dict,
|
239 |
-
progress_bar: bool=True,
|
240 |
-
early_stopping: bool=True,
|
241 |
-
metrics: list=["MeanDice", "HausdorffDistance", "SurfaceDistance"],
|
242 |
-
save_latest_metrics: bool=True
|
243 |
-
):
|
244 |
-
self.config=config
|
245 |
-
self._prepare_dirs()
|
246 |
-
self.config.device=torch.device(self.config.device)
|
247 |
-
|
248 |
-
train_loader, val_loader=segmentation_dataloaders(
|
249 |
-
config=config,
|
250 |
-
train=True,
|
251 |
-
valid=True,
|
252 |
-
test=False
|
253 |
-
)
|
254 |
-
network=get_model(config=config).to(config.device)
|
255 |
-
optimizer=get_optimizer(
|
256 |
-
network,
|
257 |
-
config=config
|
258 |
-
)
|
259 |
-
loss_fn=get_loss(config=config)
|
260 |
-
val_post_transforms=get_val_post_transforms(config=config)
|
261 |
-
val_handlers=get_val_handlers(
|
262 |
-
network,
|
263 |
-
config=config
|
264 |
-
)
|
265 |
-
|
266 |
-
self.evaluator=get_evaluator(
|
267 |
-
config=config,
|
268 |
-
device=config.device,
|
269 |
-
network=network,
|
270 |
-
val_data_loader=val_loader,
|
271 |
-
val_post_transforms=val_post_transforms,
|
272 |
-
val_handlers=val_handlers,
|
273 |
-
|
274 |
-
)
|
275 |
-
train_handlers=get_train_handlers(
|
276 |
-
self.evaluator,
|
277 |
-
config=config
|
278 |
-
)
|
279 |
-
|
280 |
-
super().__init__(
|
281 |
-
device=config.device,
|
282 |
-
max_epochs=self.config.training.max_epochs,
|
283 |
-
train_data_loader=train_loader,
|
284 |
-
network=network,
|
285 |
-
optimizer=optimizer,
|
286 |
-
loss_function=loss_fn,
|
287 |
-
inferer=monai.inferers.SimpleInferer(),
|
288 |
-
train_handlers=train_handlers,
|
289 |
-
amp=USE_AMP,
|
290 |
-
)
|
291 |
-
|
292 |
-
if early_stopping: self._add_early_stopping()
|
293 |
-
if progress_bar: self._add_progress_bars()
|
294 |
-
|
295 |
-
self.schedulers=[]
|
296 |
-
# add different metrics dynamically
|
297 |
-
for m in metrics:
|
298 |
-
getattr(monai.handlers, m)(
|
299 |
-
include_background=False,
|
300 |
-
reduction="mean",
|
301 |
-
output_transform=from_engine(
|
302 |
-
["pred", "label"]
|
303 |
-
)
|
304 |
-
).attach(self.evaluator, m)
|
305 |
-
|
306 |
-
self._add_metrics_logger()
|
307 |
-
# add eval loss to metrics
|
308 |
-
self._add_eval_loss()
|
309 |
-
|
310 |
-
if save_latest_metrics: self._add_metrics_saver()
|
311 |
-
|
312 |
-
|
313 |
-
def _prepare_dirs(self)->None:
|
314 |
-
# create run_id, copy config file for reproducibility
|
315 |
-
os.makedirs(self.config.run_id, exist_ok=True)
|
316 |
-
with open(
|
317 |
-
os.path.join(
|
318 |
-
self.config.run_id,
|
319 |
-
'config.yaml'
|
320 |
-
), 'w+') as f:
|
321 |
-
f.write(yaml.safe_dump(self.config))
|
322 |
-
|
323 |
-
# delete old log_dir
|
324 |
-
if os.path.exists(self.config.log_dir):
|
325 |
-
shutil.rmtree(self.config.log_dir)
|
326 |
-
|
327 |
-
def _add_early_stopping(self) -> None:
|
328 |
-
early_stopping=EarlyStopHandler(
|
329 |
-
patience=self.config.training.early_stopping_patience,
|
330 |
-
min_delta=1e-4,
|
331 |
-
score_function=lambda x: x.state.metrics[x.state.key_metric_name],
|
332 |
-
trainer=self
|
333 |
-
)
|
334 |
-
self.evaluator.add_event_handler(
|
335 |
-
ignite.engine.Events.COMPLETED,
|
336 |
-
early_stopping
|
337 |
-
)
|
338 |
-
|
339 |
-
def _add_metrics_logger(self) -> None:
|
340 |
-
self.metric_logger=MetricLogger(
|
341 |
-
evaluator=self.evaluator
|
342 |
-
)
|
343 |
-
self.metric_logger.attach(self)
|
344 |
-
|
345 |
-
def _add_progress_bars(self) -> None:
|
346 |
-
trainer_pbar=ProgressBar()
|
347 |
-
evaluator_pbar=ProgressBar(
|
348 |
-
colour='green'
|
349 |
-
)
|
350 |
-
trainer_pbar.attach(
|
351 |
-
self,
|
352 |
-
output_transform=lambda output:{
|
353 |
-
'loss': torch.tensor(
|
354 |
-
[x['loss'] for x in output]
|
355 |
-
).mean()
|
356 |
-
}
|
357 |
-
)
|
358 |
-
evaluator_pbar.attach(self.evaluator)
|
359 |
-
|
360 |
-
def _add_metrics_saver(self) -> None:
|
361 |
-
metric_saver=MetricsSaver(
|
362 |
-
save_dir=self.config.out_dir,
|
363 |
-
metric_details='*',
|
364 |
-
batch_transform=self._get_meta_dict,
|
365 |
-
delimiter=','
|
366 |
-
)
|
367 |
-
metric_saver.attach(self.evaluator)
|
368 |
-
|
369 |
-
def _add_eval_loss(self)->None:
|
370 |
-
# TODO improve by adding this to val handlers
|
371 |
-
eval_loss_handler=ignite.metrics.Loss(
|
372 |
-
loss_fn=self.loss_function,
|
373 |
-
output_transform=lambda output: (
|
374 |
-
output[0]['pred'].unsqueeze(0), # add batch dim
|
375 |
-
output[0]['label'].argmax(0, keepdim=True).unsqueeze(0) # reverse one-hot, add batch dim
|
376 |
-
)
|
377 |
-
)
|
378 |
-
eval_loss_handler.attach(self.evaluator, 'eval_loss')
|
379 |
-
|
380 |
-
def _get_meta_dict(self, batch) -> list:
|
381 |
-
"Get dict of metadata from engine. Needed as `batch_transform`"
|
382 |
-
image_cols=self.config.data.image_cols
|
383 |
-
image_name=image_cols[0] if isinstance(image_cols, list) else image_cols
|
384 |
-
key=f'{image_name}_meta_dict'
|
385 |
-
return [item[key] for item in batch]
|
386 |
-
|
387 |
-
def load_checkpoint(self, checkpoint=None):
|
388 |
-
if not checkpoint:
|
389 |
-
# get name of last checkpoint
|
390 |
-
checkpoint = os.path.join(
|
391 |
-
self.config.model_dir,
|
392 |
-
f"network_{self.config.run_id}_key_metric={self.evaluator.state.best_metric:.4f}.pt"
|
393 |
-
)
|
394 |
-
self.network.load_state_dict(
|
395 |
-
torch.load(checkpoint)
|
396 |
-
)
|
397 |
-
|
398 |
-
def run(self, try_resume_from_checkpoint=True) -> None:
|
399 |
-
"""Run training, if `try_resume_from_checkpoint` tries to
|
400 |
-
load previous checkpoint stored at `self.config.model_dir`
|
401 |
-
"""
|
402 |
-
|
403 |
-
if try_resume_from_checkpoint:
|
404 |
-
checkpoints = [
|
405 |
-
os.path.join(
|
406 |
-
self.config.model_dir,
|
407 |
-
checkpoint_name
|
408 |
-
) for checkpoint_name in os.listdir(
|
409 |
-
self.config.model_dir
|
410 |
-
) if self.config.run_id in checkpoint_name
|
411 |
-
]
|
412 |
-
try:
|
413 |
-
checkpoint = sorted(checkpoints)[-1]
|
414 |
-
self.load_checkpoint(checkpoint)
|
415 |
-
print(f"resuming from previous checkpoint at {checkpoint}")
|
416 |
-
except: pass # train from scratch
|
417 |
-
|
418 |
-
# train the model
|
419 |
-
super().run()
|
420 |
-
|
421 |
-
# make metrics and losses more accessible
|
422 |
-
self.loss={
|
423 |
-
"iter": [_iter for _iter, _ in self.metric_logger.loss],
|
424 |
-
"loss": [_loss for _, _loss in self.metric_logger.loss],
|
425 |
-
"epoch": [_iter // self.state.epoch_length for _iter, _ in self.metric_logger.loss]
|
426 |
-
}
|
427 |
-
|
428 |
-
self.metrics={
|
429 |
-
k: [item[1] for item in self.metric_logger.metrics[k]] for k in
|
430 |
-
self.evaluator.state.metric_details.keys()
|
431 |
-
}
|
432 |
-
# pd.DataFrame(self.metrics).to_csv(f"{self.config.out_dir}/metric_logs.csv")
|
433 |
-
# pd.DataFrame(self.loss).to_csv(f"{self.config.out_dir}/loss_logs.csv")
|
434 |
-
|
435 |
-
def fit_one_cycle(self, try_resume_from_checkpoint=True) -> None:
|
436 |
-
"Run training using one-cycle-policy"
|
437 |
-
assert "FitOneCycle" not in self.schedulers, "FitOneCycle already added"
|
438 |
-
fit_one_cycle=monai.handlers.LrScheduleHandler(
|
439 |
-
torch.optim.lr_scheduler.OneCycleLR(
|
440 |
-
optimizer=self.optimizer,
|
441 |
-
max_lr=self.optimizer.param_groups[0]['lr'],
|
442 |
-
steps_per_epoch=self.state.epoch_length,
|
443 |
-
epochs=self.state.max_epochs
|
444 |
-
),
|
445 |
-
epoch_level=False,
|
446 |
-
name="FitOneCycle"
|
447 |
-
)
|
448 |
-
fit_one_cycle.attach(self)
|
449 |
-
self.schedulers += ["FitOneCycle"]
|
450 |
-
|
451 |
-
def reduce_lr_on_plateau(self,
|
452 |
-
try_resume_from_checkpoint=True,
|
453 |
-
factor=0.1,
|
454 |
-
patience=10,
|
455 |
-
min_lr=1e-10,
|
456 |
-
verbose=True) -> None:
|
457 |
-
"Reduce learning rate by `factor` every `patience` epochs if kex_metric does not improve"
|
458 |
-
assert "ReduceLROnPlateau" not in self.schedulers, "ReduceLROnPlateau already added"
|
459 |
-
reduce_lr_on_plateau=monai.handlers.LrScheduleHandler(
|
460 |
-
torch.optim.lr_scheduler.ReduceLROnPlateau(
|
461 |
-
optimizer=self.optimizer,
|
462 |
-
factor=factor,
|
463 |
-
patience=patience,
|
464 |
-
min_lr=min_lr,
|
465 |
-
verbose=verbose
|
466 |
-
),
|
467 |
-
print_lr=True,
|
468 |
-
name='ReduceLROnPlateau',
|
469 |
-
epoch_level=True,
|
470 |
-
step_transform=lambda engine: engine.state.metrics[engine.state.key_metric_name],
|
471 |
-
)
|
472 |
-
reduce_lr_on_plateau.attach(self.evaluator)
|
473 |
-
self.schedulers += ["ReduceLROnPlateau"]
|
474 |
-
|
475 |
-
def evaluate(self, checkpoint=None, dataloader=None):
|
476 |
-
"Run evaluation with best saved checkpoint"
|
477 |
-
self.load_checkpoint(checkpoint)
|
478 |
-
if dataloader:
|
479 |
-
self.evaluator.set_data(dataloader)
|
480 |
-
self.evaluator.state.epoch_length=len(dataloader)
|
481 |
-
self.evaluator.run()
|
482 |
-
print(f"metrics saved to {self.config.out_dir}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
prostate/transforms.py
DELETED
@@ -1,363 +0,0 @@
|
|
1 |
-
# create transforms for training, validation and test dataset
|
2 |
-
|
3 |
-
## TODO: Make Transforms more dynamic by directly building from config args
|
4 |
-
## Maybe like this
|
5 |
-
## TFM_NAME=config.transforms.keys()[0]
|
6 |
-
## tfm_fun=getattr(monai.transforms, TFM_NAME)
|
7 |
-
## tmfs+=[tfms_fun(keys=image+cols, **config.transforms[TFM_NAME], prob=prob, mode=mode)
|
8 |
-
|
9 |
-
|
10 |
-
## ---------- imports ----------
|
11 |
-
import os
|
12 |
-
# only import of base transforms, others are imported as needed
|
13 |
-
from monai.utils.enums import CommonKeys
|
14 |
-
from monai.transforms import (
|
15 |
-
Activationsd,
|
16 |
-
AsDiscreted,
|
17 |
-
Compose,
|
18 |
-
ConcatItemsd,
|
19 |
-
KeepLargestConnectedComponentd,
|
20 |
-
LoadImaged,
|
21 |
-
EnsureChannelFirstd,
|
22 |
-
EnsureTyped,
|
23 |
-
SaveImaged,
|
24 |
-
ScaleIntensityd,
|
25 |
-
NormalizeIntensityd
|
26 |
-
)
|
27 |
-
# images should be interploated with `bilinear` but masks with `nearest`
|
28 |
-
|
29 |
-
## ---------- base transforms ----------
|
30 |
-
# applied everytime
|
31 |
-
def get_base_transforms(
|
32 |
-
config: dict,
|
33 |
-
minv: int=0,
|
34 |
-
maxv: int=1
|
35 |
-
)->list:
|
36 |
-
|
37 |
-
tfms=[]
|
38 |
-
tfms+=[LoadImaged(keys=config.data.image_cols+config.data.label_cols)]
|
39 |
-
tfms+=[EnsureChannelFirstd(keys=config.data.image_cols+config.data.label_cols)]
|
40 |
-
if config.transforms.spacing:
|
41 |
-
from monai.transforms import Spacingd
|
42 |
-
tfms+=[
|
43 |
-
Spacingd(
|
44 |
-
keys=config.data.image_cols+config.data.label_cols,
|
45 |
-
pixdim=config.transforms.spacing,
|
46 |
-
mode=config.transforms.mode
|
47 |
-
)
|
48 |
-
]
|
49 |
-
if config.transforms.orientation:
|
50 |
-
from monai.transforms import Orientationd
|
51 |
-
tfms+=[
|
52 |
-
Orientationd(
|
53 |
-
keys=config.data.image_cols+config.data.label_cols,
|
54 |
-
axcodes=config.transforms.orientation
|
55 |
-
)
|
56 |
-
]
|
57 |
-
tfms+=[
|
58 |
-
ScaleIntensityd(
|
59 |
-
keys=config.data.image_cols,
|
60 |
-
minv=minv,
|
61 |
-
maxv=maxv
|
62 |
-
)
|
63 |
-
]
|
64 |
-
tfms+=[NormalizeIntensityd(keys=config.data.image_cols)]
|
65 |
-
return tfms
|
66 |
-
|
67 |
-
## ---------- train transforms ----------
|
68 |
-
|
69 |
-
def get_train_transforms(config: dict):
|
70 |
-
tfms=get_base_transforms(config=config)
|
71 |
-
|
72 |
-
# ---------- specific transforms for mri ----------
|
73 |
-
if 'rand_bias_field' in config.transforms.keys():
|
74 |
-
from monai.transforms import RandBiasFieldd
|
75 |
-
args=config.transforms.rand_bias_field
|
76 |
-
tfms+=[
|
77 |
-
RandBiasFieldd(
|
78 |
-
keys=config.data.image_cols,
|
79 |
-
degree=args['degree'],
|
80 |
-
coeff_range=args['coeff_range'],
|
81 |
-
prob=config.transforms.prob
|
82 |
-
)
|
83 |
-
]
|
84 |
-
|
85 |
-
if 'rand_gaussian_smooth' in config.transforms.keys():
|
86 |
-
from monai.transforms import RandGaussianSmoothd
|
87 |
-
args=config.transforms.rand_gaussian_smooth
|
88 |
-
tfms+=[
|
89 |
-
RandGaussianSmoothd(
|
90 |
-
keys=config.data.image_cols,
|
91 |
-
sigma_x=args['sigma_x'],
|
92 |
-
sigma_y=args['sigma_y'],
|
93 |
-
sigma_z=args['sigma_z'],
|
94 |
-
prob=config.transforms.prob
|
95 |
-
)
|
96 |
-
]
|
97 |
-
|
98 |
-
if 'rand_gibbs_nose' in config.transforms.keys():
|
99 |
-
from monai.transforms import RandGibbsNoised
|
100 |
-
args=config.transforms.rand_gibbs_nose
|
101 |
-
tfms+=[
|
102 |
-
RandGibbsNoised(
|
103 |
-
keys=config.data.image_cols,
|
104 |
-
alpha=args['alpha'],
|
105 |
-
prob=config.transforms.prob
|
106 |
-
)
|
107 |
-
]
|
108 |
-
|
109 |
-
# ---------- affine transforms ----------
|
110 |
-
|
111 |
-
if 'rand_affine' in config.transforms.keys():
|
112 |
-
from monai.transforms import RandAffined
|
113 |
-
args=config.transforms.rand_affine
|
114 |
-
tfms+=[
|
115 |
-
RandAffined(
|
116 |
-
keys=config.data.image_cols+config.data.label_cols,
|
117 |
-
rotate_range=args['rotate_range'],
|
118 |
-
shear_range=args['shear_range'],
|
119 |
-
translate_range=args['translate_range'],
|
120 |
-
mode=config.transforms.mode,
|
121 |
-
prob=config.transforms.prob
|
122 |
-
)
|
123 |
-
]
|
124 |
-
|
125 |
-
if 'rand_rotate90' in config.transforms.keys():
|
126 |
-
from monai.transforms import RandRotate90d
|
127 |
-
args=config.transforms.rand_rotate90
|
128 |
-
tfms+=[
|
129 |
-
RandRotate90d(
|
130 |
-
keys=config.data.image_cols+config.data.label_cols,
|
131 |
-
spatial_axes=args['spatial_axes'],
|
132 |
-
prob=config.transforms.prob
|
133 |
-
)
|
134 |
-
]
|
135 |
-
|
136 |
-
if 'rand_rotate' in config.transforms.keys():
|
137 |
-
from monai.transforms import RandRotated
|
138 |
-
args=config.transforms.rand_rotate
|
139 |
-
tfms+=[
|
140 |
-
RandRotated(
|
141 |
-
keys=config.data.image_cols+config.data.label_cols,
|
142 |
-
range_x=args['range_x'],
|
143 |
-
range_y=args['range_y'],
|
144 |
-
range_z=args['range_z'],
|
145 |
-
mode=config.transforms.mode,
|
146 |
-
prob=config.transforms.prob
|
147 |
-
)
|
148 |
-
]
|
149 |
-
|
150 |
-
if 'rand_elastic' in config.transforms.keys():
|
151 |
-
if config['ndim'] == 3:
|
152 |
-
from monai.transforms import Rand3DElasticd as RandElasticd
|
153 |
-
elif config['ndim'] == 2:
|
154 |
-
from monai.transforms import Rand2DElasticd as RandElasticd
|
155 |
-
args=config.transforms.rand_elastic
|
156 |
-
tfms+=[
|
157 |
-
RandElasticd(
|
158 |
-
keys=config.data.image_cols+config.data.label_cols,
|
159 |
-
sigma_range=args['sigma_range'],
|
160 |
-
magnitude_range=args['magnitude_range'],
|
161 |
-
rotate_range=args['rotate_range'],
|
162 |
-
shear_range=args['shear_range'],
|
163 |
-
translate_range=args['translate_range'],
|
164 |
-
mode=config.transforms.mode,
|
165 |
-
prob=config.transforms.prob
|
166 |
-
)
|
167 |
-
]
|
168 |
-
|
169 |
-
if 'rand_zoom' in config.transforms.keys():
|
170 |
-
from monai.transforms import RandZoomd
|
171 |
-
args=config.transforms.rand_zoom
|
172 |
-
tfms+=[
|
173 |
-
RandZoomd(
|
174 |
-
keys=config.data.image_cols+config.data.label_cols,
|
175 |
-
min_zoom=args['min'],
|
176 |
-
max_zoom=args['max'],
|
177 |
-
mode=['area' if x == 'bilinear' else x for x in config.transforms.mode],
|
178 |
-
prob=config.transforms.prob
|
179 |
-
)
|
180 |
-
]
|
181 |
-
|
182 |
-
# ---------- random cropping, very effective for large images ----------
|
183 |
-
# RandCropByPosNegLabeld is not advisable for data with missing lables
|
184 |
-
# e.g., segmentation of carcinomas which are not present on all images
|
185 |
-
# thus fallback to RandSpatialCropSamplesd. Completly replacing Cropping
|
186 |
-
# by just resizing could be discussed, but I believe it is not beneficial
|
187 |
-
# For the first version, this is an ungly hack. For the second version,
|
188 |
-
# a better verion for transforms should be written.
|
189 |
-
|
190 |
-
if 'rand_crop_pos_neg_label' in config.transforms.keys():
|
191 |
-
from monai.transforms import RandCropByPosNegLabeld
|
192 |
-
args=config.transforms.rand_crop_pos_neg_label
|
193 |
-
tfms+=[
|
194 |
-
RandCropByPosNegLabeld(
|
195 |
-
keys=config.data.image_cols+config.data.label_cols,
|
196 |
-
label_key=config.data.label_cols[0],
|
197 |
-
spatial_size=args['spatial_size'],
|
198 |
-
pos=args['pos'],
|
199 |
-
neg=args['neg'],
|
200 |
-
num_samples=args['num_samples'],
|
201 |
-
image_key=config.data.image_cols[0],
|
202 |
-
image_threshold=0,
|
203 |
-
)
|
204 |
-
]
|
205 |
-
|
206 |
-
elif 'rand_spatial_crop_samples' in config.transforms.keys():
|
207 |
-
from monai.transforms import RandSpatialCropSamplesd
|
208 |
-
args=config.transforms.rand_spatial_crop_samples
|
209 |
-
tfms+=[
|
210 |
-
RandSpatialCropSamplesd(
|
211 |
-
keys=config.data.image_cols+config.data.label_cols,
|
212 |
-
roi_size=args['roi_size'],
|
213 |
-
random_size=False,
|
214 |
-
num_samples=args['num_samples'],
|
215 |
-
)
|
216 |
-
]
|
217 |
-
|
218 |
-
else:
|
219 |
-
raise ValueError('Either `rand_crop_pos_neg_label` or `rand_spatial_crop_samples` '\
|
220 |
-
'need to be specified')
|
221 |
-
|
222 |
-
# ---------- intensity transforms ----------
|
223 |
-
|
224 |
-
if 'gaussian_noise' in config.transforms.keys():
|
225 |
-
from monai.transforms import RandGaussianNoised
|
226 |
-
args=config.transforms.gaussian_noise
|
227 |
-
tfms+=[
|
228 |
-
RandGaussianNoised(
|
229 |
-
keys=config.data.image_cols,
|
230 |
-
mean=args['mean'],
|
231 |
-
std=args['std'],
|
232 |
-
prob=config.transforms.prob
|
233 |
-
)
|
234 |
-
]
|
235 |
-
|
236 |
-
if 'shift_intensity' in config.transforms.keys():
|
237 |
-
from monai.transforms import RandShiftIntensityd
|
238 |
-
args=config.transforms.shift_intensity
|
239 |
-
tfms+=[
|
240 |
-
RandShiftIntensityd(
|
241 |
-
keys=config.data.image_cols,
|
242 |
-
offsets=args['offsets'],
|
243 |
-
prob=config.transforms.prob
|
244 |
-
)
|
245 |
-
]
|
246 |
-
|
247 |
-
if 'gaussian_sharpen' in config.transforms.keys():
|
248 |
-
from monai.transforms import RandGaussianSharpend
|
249 |
-
args=config.transforms.gaussian_sharpen
|
250 |
-
tfms+=[
|
251 |
-
RandGaussianSharpend(
|
252 |
-
keys=config.data.image_cols,
|
253 |
-
sigma1_x=args['sigma1_x'],
|
254 |
-
sigma1_y=args['sigma1_y'],
|
255 |
-
sigma1_z=args['sigma1_z'],
|
256 |
-
sigma2_x=args['sigma2_x'],
|
257 |
-
sigma2_y=args['sigma2_y'],
|
258 |
-
sigma2_z=args['sigma2_z'],
|
259 |
-
alpha=args['alpha'],
|
260 |
-
prob=config.transforms.prob
|
261 |
-
)
|
262 |
-
]
|
263 |
-
|
264 |
-
if 'adjust_contrast' in config.transforms.keys():
|
265 |
-
from monai.transforms import RandAdjustContrastd
|
266 |
-
args=config.transforms.adjust_contrast
|
267 |
-
tfms+=[
|
268 |
-
RandAdjustContrastd(
|
269 |
-
keys=config.data.image_cols,
|
270 |
-
gamma=args['gamma'],
|
271 |
-
prob=config.transforms.prob
|
272 |
-
)
|
273 |
-
]
|
274 |
-
|
275 |
-
# Concat mutlisequence data to single Tensors on the ChannelDim
|
276 |
-
# Rename images to `CommonKeys.IMAGE` and labels to `CommonKeys.LABELS`
|
277 |
-
# for more compatibility with monai.engines
|
278 |
-
|
279 |
-
tfms+=[
|
280 |
-
ConcatItemsd(
|
281 |
-
keys=config.data.image_cols,
|
282 |
-
name=CommonKeys.IMAGE,
|
283 |
-
dim=0
|
284 |
-
)
|
285 |
-
]
|
286 |
-
|
287 |
-
tfms+=[
|
288 |
-
ConcatItemsd(
|
289 |
-
keys=config.data.label_cols,
|
290 |
-
name=CommonKeys.LABEL,
|
291 |
-
dim=0
|
292 |
-
)
|
293 |
-
]
|
294 |
-
|
295 |
-
return Compose(tfms)
|
296 |
-
|
297 |
-
## ---------- valid transforms ----------
|
298 |
-
|
299 |
-
def get_val_transforms(config: dict):
|
300 |
-
tfms=get_base_transforms(config=config)
|
301 |
-
tfms+=[EnsureTyped(keys=config.data.image_cols+config.data.label_cols)]
|
302 |
-
tfms+=[
|
303 |
-
ConcatItemsd(
|
304 |
-
keys=config.data.image_cols,
|
305 |
-
name=CommonKeys.IMAGE,
|
306 |
-
dim=0
|
307 |
-
)
|
308 |
-
]
|
309 |
-
|
310 |
-
tfms+=[
|
311 |
-
ConcatItemsd(
|
312 |
-
keys=config.data.label_cols,
|
313 |
-
name=CommonKeys.LABEL,
|
314 |
-
dim=0
|
315 |
-
)
|
316 |
-
]
|
317 |
-
|
318 |
-
return Compose(tfms)
|
319 |
-
|
320 |
-
## ---------- test transforms ----------
|
321 |
-
# same as valid transforms
|
322 |
-
|
323 |
-
def get_test_transforms(config: dict):
|
324 |
-
tfms=get_base_transforms(config=config)
|
325 |
-
tfms+=[EnsureTyped(keys=config.data.image_cols+config.data.label_cols)]
|
326 |
-
tfms+=[
|
327 |
-
ConcatItemsd(
|
328 |
-
keys=config.data.image_cols,
|
329 |
-
name=CommonKeys.IMAGE,
|
330 |
-
dim=0
|
331 |
-
)
|
332 |
-
]
|
333 |
-
|
334 |
-
tfms+=[
|
335 |
-
ConcatItemsd(
|
336 |
-
keys=config.data.label_cols,
|
337 |
-
name=CommonKeys.LABEL,
|
338 |
-
dim=0
|
339 |
-
)
|
340 |
-
]
|
341 |
-
|
342 |
-
return Compose(tfms)
|
343 |
-
|
344 |
-
|
345 |
-
def get_val_post_transforms(config: dict):
|
346 |
-
tfms=[EnsureTyped(keys=[CommonKeys.PRED, CommonKeys.LABEL]),
|
347 |
-
AsDiscreted(
|
348 |
-
keys=CommonKeys.PRED,
|
349 |
-
argmax=True,
|
350 |
-
to_onehot=config.model.out_channels,
|
351 |
-
num_classes=config.model.out_channels
|
352 |
-
),
|
353 |
-
AsDiscreted(
|
354 |
-
keys=CommonKeys.LABEL,
|
355 |
-
to_onehot=config.model.out_channels,
|
356 |
-
num_classes=config.model.out_channels
|
357 |
-
),
|
358 |
-
KeepLargestConnectedComponentd(
|
359 |
-
keys=CommonKeys.PRED,
|
360 |
-
applied_labels=list(range(1, config.model.out_channels))
|
361 |
-
),
|
362 |
-
]
|
363 |
-
return Compose(tfms)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
prostate/utils.py
DELETED
@@ -1,56 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import yaml
|
3 |
-
import torch
|
4 |
-
import monai
|
5 |
-
import munch
|
6 |
-
|
7 |
-
def load_config(fn: str='config.yaml'):
|
8 |
-
"Load config from YAML and return a serialized dictionary object"
|
9 |
-
with open(fn, 'r') as stream:
|
10 |
-
config=yaml.safe_load(stream)
|
11 |
-
config=munch.munchify(config)
|
12 |
-
|
13 |
-
if not config.overwrite:
|
14 |
-
i=1
|
15 |
-
while os.path.exists(config.run_id+f'_{i}'):
|
16 |
-
i+=1
|
17 |
-
config.run_id+=f'_{i}'
|
18 |
-
|
19 |
-
config.out_dir = os.path.join(config.run_id, config.out_dir)
|
20 |
-
config.log_dir = os.path.join(config.run_id, config.log_dir)
|
21 |
-
|
22 |
-
if not isinstance(config.data.image_cols, (tuple, list)):
|
23 |
-
config.data.image_cols=[config.data.image_cols]
|
24 |
-
if not isinstance(config.data.label_cols, (tuple, list)):
|
25 |
-
config.data.label_cols=[config.data.label_cols]
|
26 |
-
|
27 |
-
config.transforms.mode=('bilinear', ) * len(config.data.image_cols) + \
|
28 |
-
('nearest', ) * len(config.data.label_cols)
|
29 |
-
return config
|
30 |
-
|
31 |
-
|
32 |
-
def num_workers():
|
33 |
-
"Get max supported workers -2 for multiprocessing"
|
34 |
-
import resource
|
35 |
-
import multiprocessing
|
36 |
-
|
37 |
-
# first check for max number of open files allowed on system
|
38 |
-
soft_limit, hard_limit = resource.getrlimit(resource.RLIMIT_NOFILE)
|
39 |
-
n_workers=multiprocessing.cpu_count() - 2
|
40 |
-
# giving each worker at least 256 open processes should allow them to run smoothly
|
41 |
-
max_workers = soft_limit // 256
|
42 |
-
if max_workers < n_workers:
|
43 |
-
print(
|
44 |
-
"Will not use all available workers as number of allowed open files is to small"
|
45 |
-
"to ensure smooth multiprocessing. Current limits are:\n"
|
46 |
-
f"\t soft_limit: {soft_limit}\n"
|
47 |
-
f"\t hard_limit: {hard_limit}\n"
|
48 |
-
"try increasing the limits to at least {256*n_workers}."
|
49 |
-
"See https://superuser.com/questions/1200539/cannot-increase-open-file-limit-past-4096-ubuntu"
|
50 |
-
"for more details"
|
51 |
-
)
|
52 |
-
return max_workers
|
53 |
-
|
54 |
-
return n_workers
|
55 |
-
|
56 |
-
USE_AMP=True if monai.utils.get_torch_version_tuple() >= (1, 6) else False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
prostate/viewer.py
DELETED
@@ -1,352 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
import ipywidgets
|
3 |
-
import numpy as np
|
4 |
-
import matplotlib.pyplot as plt
|
5 |
-
from IPython.display import display
|
6 |
-
from itertools import chain, islice
|
7 |
-
from ipywidgets import interactive, widgets
|
8 |
-
|
9 |
-
def _create_label(text:str)->ipywidgets.widgets.Label:
|
10 |
-
"Create label widget"
|
11 |
-
|
12 |
-
label = widgets.Label(
|
13 |
-
text,
|
14 |
-
layout=widgets.Layout(
|
15 |
-
width='100%',
|
16 |
-
display='flex',
|
17 |
-
justify_content="center"
|
18 |
-
)
|
19 |
-
)
|
20 |
-
return label
|
21 |
-
|
22 |
-
def _create_slider(
|
23 |
-
slider_min: int,
|
24 |
-
slider_max: int,
|
25 |
-
value: int,
|
26 |
-
step: int=1,
|
27 |
-
description:str ='',
|
28 |
-
continuous_update: bool=True,
|
29 |
-
readout: bool=False,
|
30 |
-
slider_type: str='IntSlider',
|
31 |
-
**kwargs)->ipywidgets.widgets:
|
32 |
-
"Create slider widget"
|
33 |
-
|
34 |
-
slider = getattr(widgets, slider_type)(
|
35 |
-
min=slider_min,
|
36 |
-
max=slider_max,
|
37 |
-
step=step,
|
38 |
-
value=value,
|
39 |
-
description=description,
|
40 |
-
continuous_update=continuous_update,
|
41 |
-
readout = readout,
|
42 |
-
layout=widgets.Layout(width='99%', min_width='200px'),
|
43 |
-
style={'description_width': 'initial'},
|
44 |
-
**kwargs
|
45 |
-
)
|
46 |
-
return slider
|
47 |
-
|
48 |
-
def _create_button(description:str)->ipywidgets.widgets.Button:
|
49 |
-
"Create button widget"
|
50 |
-
button = widgets.Button(
|
51 |
-
description=description,
|
52 |
-
layout=widgets.Layout(
|
53 |
-
width='95%',
|
54 |
-
margin='5px 5px'
|
55 |
-
)
|
56 |
-
)
|
57 |
-
return button
|
58 |
-
|
59 |
-
def _create_togglebutton(description: str,
|
60 |
-
value: int,
|
61 |
-
**kwargs)->ipywidgets.widgets.Button:
|
62 |
-
"Create toggle button widget"
|
63 |
-
button = widgets.ToggleButton(
|
64 |
-
description=description,
|
65 |
-
value = value,
|
66 |
-
layout=widgets.Layout(
|
67 |
-
width='95%',
|
68 |
-
margin='5px 5px'
|
69 |
-
), **kwargs
|
70 |
-
)
|
71 |
-
return button
|
72 |
-
|
73 |
-
|
74 |
-
class BasicViewer():
|
75 |
-
""" Base class for viewing TensorDicom3D objects.
|
76 |
-
|
77 |
-
Args:
|
78 |
-
x: main image object to view as rank 3 tensor
|
79 |
-
y: either a segmentation mask as as rank 3 tensor or a label as str.
|
80 |
-
prediction: a class predicton as str
|
81 |
-
description: description of the whole image
|
82 |
-
figsize: size of image, passed as plotting argument
|
83 |
-
cmap: colormap for the image
|
84 |
-
Returns:
|
85 |
-
Instance of BasicViewer
|
86 |
-
"""
|
87 |
-
|
88 |
-
def __init__(self, x:torch.Tensor, y=None, prediction:str=None, description: str=None,
|
89 |
-
figsize=(3, 3), cmap:str='bone'):
|
90 |
-
assert x.ndim == 3, f"x.ndim needs to be equal to but is {x.ndim}"
|
91 |
-
if isinstance(y, torch.Tensor):
|
92 |
-
assert x.shape == y.shape, f"Shapes of x {x.shape} and y {y.shape} do not match"
|
93 |
-
self.x=x
|
94 |
-
self.y=y
|
95 |
-
self.prediction=prediction
|
96 |
-
self.description=description
|
97 |
-
self.figsize=figsize
|
98 |
-
self.cmap=cmap
|
99 |
-
self.with_mask = isinstance(y, torch.Tensor)
|
100 |
-
self.slice_range = (1, len(x)) # len(x) == im.shape[0]
|
101 |
-
|
102 |
-
def _plot_slice(self, im_slice, with_mask, px_range):
|
103 |
-
"Plot slice of image"
|
104 |
-
fig, ax = plt.subplots(1, 1, figsize=self.figsize)
|
105 |
-
ax.imshow(self.x[im_slice-1, :, :].clip(*px_range), cmap=self.cmap)
|
106 |
-
if isinstance(self.y, (torch.Tensor)) and with_mask:
|
107 |
-
ax.imshow(self.y[im_slice-1, :, :], cmap='jet', alpha = 0.25)
|
108 |
-
plt.axis('off')
|
109 |
-
ax.set_xticks([])
|
110 |
-
ax.set_yticks([])
|
111 |
-
plt.show()
|
112 |
-
|
113 |
-
def _create_image_box(self, figsize):
|
114 |
-
"Create widget items, order them in item_box and generate view box"
|
115 |
-
items = []
|
116 |
-
|
117 |
-
if self.description: plot_description = _create_label(self.description)
|
118 |
-
|
119 |
-
if isinstance(self.y, str):
|
120 |
-
label = f'{self.y} | {self.prediction}' if self.prediction else self.y
|
121 |
-
if self.prediction:
|
122 |
-
font_color = 'green' if self.y == self.prediction else 'red'
|
123 |
-
y_label = _create_label(r'\(\color{' + font_color + '} {' + label + '}\)')
|
124 |
-
else:
|
125 |
-
y_label = _create_label(label)
|
126 |
-
else: y_label = _create_label(' ')
|
127 |
-
|
128 |
-
slice_slider = _create_slider(
|
129 |
-
slider_min = min(self.slice_range),
|
130 |
-
slider_max = max(self.slice_range),
|
131 |
-
value = max(self.slice_range)//2,
|
132 |
-
readout = True)
|
133 |
-
|
134 |
-
toggle_mask_button = _create_togglebutton('Show Mask', True)
|
135 |
-
|
136 |
-
range_slider = _create_slider(
|
137 |
-
slider_min = self.x.min().numpy(),
|
138 |
-
slider_max = self.x.max().numpy(),
|
139 |
-
value = [self.x.min().numpy(), self.x.max().numpy()],
|
140 |
-
slider_type = 'FloatRangeSlider' if torch.is_floating_point(self.x) else 'IntRandSlider',
|
141 |
-
step = 0.01 if torch.is_floating_point(self.x) else 1,
|
142 |
-
readout=True)
|
143 |
-
|
144 |
-
image_output = widgets.interactive_output(
|
145 |
-
f = self._plot_slice,
|
146 |
-
controls = {'im_slice': slice_slider,
|
147 |
-
'with_mask': toggle_mask_button,
|
148 |
-
'px_range': range_slider})
|
149 |
-
|
150 |
-
image_output.layout.height = f'{self.figsize[0]/1.2}in' # suppress flickering
|
151 |
-
image_output.layout.width = f'{self.figsize[1]/1.2}in' # suppress flickering
|
152 |
-
|
153 |
-
if self.description: items.append(plot_description)
|
154 |
-
items.append(y_label)
|
155 |
-
items.append(range_slider)
|
156 |
-
items.append(image_output)
|
157 |
-
if isinstance(self.y, torch.Tensor):
|
158 |
-
slice_slider = widgets.HBox([slice_slider, toggle_mask_button])
|
159 |
-
items.append(slice_slider)
|
160 |
-
|
161 |
-
image_box=widgets.VBox(
|
162 |
-
items,
|
163 |
-
layout = widgets.Layout(
|
164 |
-
border = 'none',
|
165 |
-
margin = '10px 5px 0px 0px',
|
166 |
-
padding = '5px'))
|
167 |
-
|
168 |
-
return image_box
|
169 |
-
|
170 |
-
def _generate_views(self):
|
171 |
-
image_box = self._create_image_box(self.figsize)
|
172 |
-
self.box = widgets.HBox(children=[image_box])
|
173 |
-
|
174 |
-
@property
|
175 |
-
def image_box(self):
|
176 |
-
return self._create_image_box(self.figsize)
|
177 |
-
|
178 |
-
def show(self):
|
179 |
-
self._generate_views()
|
180 |
-
plt.style.use('default')
|
181 |
-
display(self.box)
|
182 |
-
|
183 |
-
|
184 |
-
class DicomExplorer(BasicViewer):
|
185 |
-
""" DICOM viewer for basic image analysis inside iPython notebooks.
|
186 |
-
Can display a single 3D volume together with a segmentation mask, a histogram
|
187 |
-
of voxel/pixel values and some summary statistics.
|
188 |
-
Allows simple windowing by clipping the pixel/voxel values to a region, which
|
189 |
-
can be manually specified.
|
190 |
-
|
191 |
-
"""
|
192 |
-
|
193 |
-
vbox_layout = widgets.Layout(
|
194 |
-
margin = '10px 5px 5px 5px',
|
195 |
-
padding = '5px',
|
196 |
-
display='flex',
|
197 |
-
flex_flow='column',
|
198 |
-
align_items='center',
|
199 |
-
min_width = '250px')
|
200 |
-
|
201 |
-
def _plot_hist(self, px_range):
|
202 |
-
x = self.x.numpy().flatten()
|
203 |
-
fig, ax = plt.subplots(figsize=self.figsize)
|
204 |
-
N, bins, patches = plt.hist(x, 100, color='grey')
|
205 |
-
lwr = int(px_range[0] * 100/max(x))
|
206 |
-
upr = int(np.ceil(px_range[1] * 100/max(x)))
|
207 |
-
|
208 |
-
for i in range(0,lwr):
|
209 |
-
patches[i].set_facecolor('grey' if lwr > 0 else 'darkblue')
|
210 |
-
for i in range(lwr, upr):
|
211 |
-
patches[i].set_facecolor('darkblue')
|
212 |
-
for i in range(upr,100):
|
213 |
-
patches[i].set_facecolor('grey' if upr < 100 else 'darkblue')
|
214 |
-
|
215 |
-
plt.show()
|
216 |
-
|
217 |
-
def _image_summary(self, px_range):
|
218 |
-
x = self.x.clip(*px_range)
|
219 |
-
|
220 |
-
diffs = x - x.mean()
|
221 |
-
var = torch.mean(torch.pow(diffs, 2.0))
|
222 |
-
std = torch.pow(var, 0.5)
|
223 |
-
zscores = diffs / std
|
224 |
-
skews = torch.mean(torch.pow(zscores, 3.0))
|
225 |
-
kurt = torch.mean(torch.pow(zscores, 4.0)) - 3.0
|
226 |
-
|
227 |
-
table = f'Statistics:\n' + \
|
228 |
-
f' Mean px value: {x.mean()} \n' + \
|
229 |
-
f' Std of px values: {x.std()} \n' + \
|
230 |
-
f' Min px value: {x.min()} \n' + \
|
231 |
-
f' Max px value: {x.max()} \n' + \
|
232 |
-
f' Median px value: {x.median()} \n' + \
|
233 |
-
f' Skewness: {skews} \n' + \
|
234 |
-
f' Kurtosis: {kurt} \n\n' + \
|
235 |
-
f'Tensor properties \n' + \
|
236 |
-
f' Tensor shape: {tuple(x.shape)}\n' + \
|
237 |
-
f' Tensor dtype: {x.dtype}'
|
238 |
-
print(table)
|
239 |
-
|
240 |
-
def _generate_views(self):
|
241 |
-
|
242 |
-
slice_slider = _create_slider(
|
243 |
-
slider_min = min(self.slice_range),
|
244 |
-
slider_max = max(self.slice_range),
|
245 |
-
value = max(self.slice_range)//2,
|
246 |
-
readout = True)
|
247 |
-
|
248 |
-
toggle_mask_button = _create_togglebutton('Show Mask', True)
|
249 |
-
|
250 |
-
range_slider = _create_slider(
|
251 |
-
slider_min = self.x.min().numpy(),
|
252 |
-
slider_max = self.x.max().numpy(),
|
253 |
-
value = [self.x.min().numpy(), self.x.max().numpy()],
|
254 |
-
continuous_update=False,
|
255 |
-
slider_type = 'FloatRangeSlider' if torch.is_floating_point(self.x) else 'IntRandSlider',
|
256 |
-
step = 0.01 if torch.is_floating_point(self.x) else 1)
|
257 |
-
|
258 |
-
image_output = widgets.interactive_output(
|
259 |
-
f = self._plot_slice,
|
260 |
-
controls = {'im_slice': slice_slider,
|
261 |
-
'with_mask': toggle_mask_button,
|
262 |
-
'px_range': range_slider})
|
263 |
-
|
264 |
-
image_output.layout.height = f'{self.figsize[0]/1.2}in' # suppress flickering
|
265 |
-
image_output.layout.width = f'{self.figsize[1]/1.2}in' # suppress flickering
|
266 |
-
|
267 |
-
if isinstance(self.y, torch.Tensor):
|
268 |
-
slice_slider = widgets.HBox([slice_slider, toggle_mask_button])
|
269 |
-
|
270 |
-
hist_output = widgets.interactive_output(
|
271 |
-
f = self._plot_hist,
|
272 |
-
controls = {'px_range': range_slider})
|
273 |
-
|
274 |
-
hist_output.layout.height = f'{self.figsize[0]/1.2}in' # suppress flickering
|
275 |
-
hist_output.layout.width = f'{self.figsize[1]/1.2}in' # suppress flickering
|
276 |
-
|
277 |
-
toggle_mask_button = _create_togglebutton('Show Mask', True)
|
278 |
-
|
279 |
-
table_output = widgets.interactive_output(
|
280 |
-
f = self._image_summary,
|
281 |
-
controls = {'px_range': range_slider})
|
282 |
-
|
283 |
-
table_box = widgets.VBox([table_output], layout=self.vbox_layout)
|
284 |
-
|
285 |
-
hist_box = widgets.VBox(
|
286 |
-
[hist_output, range_slider],
|
287 |
-
layout=self.vbox_layout)
|
288 |
-
|
289 |
-
image_box = widgets.VBox(
|
290 |
-
[image_output, slice_slider],
|
291 |
-
layout=self.vbox_layout)
|
292 |
-
|
293 |
-
self.box = widgets.HBox(
|
294 |
-
[image_box, hist_box, table_box],
|
295 |
-
layout = widgets.Layout(
|
296 |
-
border = 'solid 1px lightgrey',
|
297 |
-
margin = '10px 5px 0px 0px',
|
298 |
-
padding = '5px',
|
299 |
-
width = f'{self.figsize[1]*2 + 3}in'))
|
300 |
-
|
301 |
-
|
302 |
-
class ListViewer(object):
|
303 |
-
""" Display multipple images with their masks or labels/predictions.
|
304 |
-
Arguments:
|
305 |
-
x (tuple, list): Tensor objects to view
|
306 |
-
y (tuple, list): Tensor objects (in case of segmentation task) or class labels as string.
|
307 |
-
predictions (str): Class predictions
|
308 |
-
cmap: colormap for display of `x`
|
309 |
-
max_n: maximum number of items to display
|
310 |
-
"""
|
311 |
-
|
312 |
-
def __init__(self, x:(list, tuple), y=None, prediction:str=None, description: str=None,
|
313 |
-
figsize=(4, 4), cmap:str='bone', max_n = 9):
|
314 |
-
self.slice_range = (1, len(x))
|
315 |
-
x = x[0:max_n]
|
316 |
-
if y: y = y[0:max_n]
|
317 |
-
self.x=x
|
318 |
-
self.y=y
|
319 |
-
self.prediction=prediction
|
320 |
-
self.description=description
|
321 |
-
self.figsize=figsize
|
322 |
-
self.cmap=cmap
|
323 |
-
self.max_n=max_n
|
324 |
-
|
325 |
-
def _generate_views(self):
|
326 |
-
n_images = len(self.x)
|
327 |
-
image_grid, image_list = [], []
|
328 |
-
|
329 |
-
for i in range(0, n_images):
|
330 |
-
image = self.x[i]
|
331 |
-
mask = self.y[i] if isinstance(self.y, list) else None
|
332 |
-
pred = self.prediction[i] if self.prediction else None
|
333 |
-
|
334 |
-
image_list.append(
|
335 |
-
BasicViewer(
|
336 |
-
x = image,
|
337 |
-
y = mask,
|
338 |
-
prediction = pred,
|
339 |
-
figsize = self.figsize,
|
340 |
-
cmap = self.cmap)
|
341 |
-
.image_box)
|
342 |
-
|
343 |
-
if (i+1) % np.ceil(np.sqrt(n_images)) == 0 or i == n_images - 1:
|
344 |
-
image_grid.append(widgets.HBox(image_list))
|
345 |
-
image_list = []
|
346 |
-
|
347 |
-
self.box = widgets.VBox(children=image_grid)
|
348 |
-
|
349 |
-
def show(self):
|
350 |
-
self._generate_views()
|
351 |
-
plt.style.use('default')
|
352 |
-
display(self.box)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|