Commit 78ae283 · 1 parent: 286a978
Updating latest changes DDMR

Files changed:
- DeepDeformationMapRegistration/layers/augmentation.py  +1 -1
- DeepDeformationMapRegistration/layers/upsampling.py  +2 -0
- DeepDeformationMapRegistration/utils/constants.py  +3 -2
- DeepDeformationMapRegistration/utils/misc.py  +69 -10
- DeepDeformationMapRegistration/utils/operators.py  +15 -0
- DeepDeformationMapRegistration/utils/visualization.py  +43 -14
- requirements.txt  +12 -7
DeepDeformationMapRegistration/layers/augmentation.py
CHANGED
@@ -133,7 +133,7 @@ class AugmentationLayer(kl.Layer):
         mov_img = tf.zeros_like(fix_img)
         mov_segm = tf.zeros_like(fix_segm)

-        disp_map = tf.tile(tf.zeros_like(fix_img), [1, 1, 1, 1, 3])
+        disp_map = tf.tile(tf.zeros_like(fix_img), [1, 1, 1, 1, 3]) # TODO: change, don't use tile!!

         if self.out_img_shape is not None:
             fix_img = self.downsize_image(fix_img)
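The TODO flags the tf.tile call. As a hedged sketch (not part of the commit), one tile-free way to get the same zero displacement map is to build it straight from the image shape; zero_disp_map below is a hypothetical helper and fix_img is assumed to be a (batch, H, W, D, 1) tensor:

import tensorflow as tf

def zero_disp_map(fix_img):
    # Hypothetical alternative to tiling a zeros tensor: allocate the
    # (batch, H, W, D, 3) zero displacement map directly from the image shape.
    shape = tf.concat([tf.shape(fix_img)[:-1], [3]], axis=0)
    return tf.zeros(shape, dtype=fix_img.dtype)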
DeepDeformationMapRegistration/layers/upsampling.py
CHANGED
@@ -485,6 +485,8 @@ def UpInterpolate3D(x,
         nb, nr, nc, nd, nh = tf.TensorShape(x).as_list()
     elif data_format == 'channels_first':
         nb, nh, nr, nc, nd = tf.TensorShape(x).as_list()
+    else:
+        raise ValueError('Invalid option: ', data_format)

     r = size[0]
     c = size[1]
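The new else branch makes unsupported data_format values fail fast instead of leaving the shape variables undefined. A hedged, standalone illustration of the same guard pattern (the helper below is illustrative only, not the repo's UpInterpolate3D):

def _unpack_shape(x_shape, data_format):
    # Illustrative stand-in for the shape-unpacking block in UpInterpolate3D.
    if data_format == 'channels_last':
        nb, nr, nc, nd, nh = x_shape
    elif data_format == 'channels_first':
        nb, nh, nr, nc, nd = x_shape
    else:
        raise ValueError('Invalid option: ', data_format)
    return nb, nr, nc, nd, nh

_unpack_shape((1, 64, 64, 64, 2), 'channels_last')    # ok
# _unpack_shape((1, 64, 64, 64, 2), 'channel_last')   # would raise ValueError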
DeepDeformationMapRegistration/utils/constants.py
CHANGED
@@ -413,6 +413,7 @@ WAR = 30 # Warning
 ERR = 40 # Error
 DEB = 10 # Debug
 CRI = 50 # Critical
+SUMMARY_LINE_LENGTH = 150

 SEVERITY_STR = {INF: 'INFO',
                 WAR: 'WARNING',
@@ -511,8 +512,8 @@ REG_MANUAL_W = [1.] * len(REG_PRIOR_W)
 IXI_DATASET_iso_to_cubic_scales = np.asarray([0.655491 + 0.039223, 0.496783 + 0.029349, 0.499691 + 0.028155])
 # ...OSLO_COMET_CT/Formatted_128x128x128/zoom_factors.csv contain the scale factors of all the training samples from isotropic to 128x128x128
 COMET_DATASET_iso_to_cubic_scales = np.asarray([0.455259 + 0.048027, 0.492012 + 0.044298, 0.577552 + 0.051708])
-MAX_AUG_DISP_ISOT = 30
-MAX_AUG_DEF_ISOT = 6
+MAX_AUG_DISP_ISOT = 30 # mm
+MAX_AUG_DEF_ISOT = 6 # mm
 MAX_AUG_DISP = np.max(MAX_AUG_DISP_ISOT * IXI_DATASET_iso_to_cubic_scales) # Scaled displacements
 MAX_AUG_DEF = np.max(MAX_AUG_DEF_ISOT * IXI_DATASET_iso_to_cubic_scales) # Scaled deformations
 MAX_AUG_ANGLE = np.max([np.arctan(np.tan(10*np.pi/180) * IXI_DATASET_iso_to_cubic_scales[1] / IXI_DATASET_iso_to_cubic_scales[0]) * 180 / np.pi,
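The new # mm comments mark the isotropic limits as millimetres before scaling. As a hedged side calculation (mine, not part of the commit), this is what the MAX_AUG_DISP line evaluates to with the IXI factors above:

import numpy as np

IXI_DATASET_iso_to_cubic_scales = np.asarray([0.655491 + 0.039223, 0.496783 + 0.029349, 0.499691 + 0.028155])
MAX_AUG_DISP_ISOT = 30 # mm, limit in isotropic space

# Largest per-axis limit after rescaling to the 128x128x128 grid: ~20.84
MAX_AUG_DISP = np.max(MAX_AUG_DISP_ISOT * IXI_DATASET_iso_to_cubic_scales)
print(round(float(MAX_AUG_DISP), 2))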
DeepDeformationMapRegistration/utils/misc.py
CHANGED
@@ -8,6 +8,7 @@ from DeepDeformationMapRegistration.layers.b_splines import interpolate_spline
 from DeepDeformationMapRegistration.utils.thin_plate_splines import ThinPlateSplines
 from tensorflow import squeeze
 from scipy.ndimage import zoom
+import tensorflow as tf


 def try_mkdir(dir, verbose=True):
@@ -119,15 +120,22 @@ class DisplacementMapInterpolator:
         return disp


-def get_segmentations_centroids(segmentations, ohe=True, expected_lbls=range(
+def get_segmentations_centroids(segmentations, ohe=True, expected_lbls=range(1, 28), missing_centroid=[np.nan]*3, brain_study=True):
     segmentations = np.squeeze(segmentations)
     if ohe:
-        segmentations =
-
-
-
+        segmentations = segmentation_ohe_to_cardinal(segmentations)
+        lbls = set(np.unique(segmentations)) - {0} # Remove the 0 value returned by np.unique, no label
+        # missing_lbls = set(expected_lbls) - lbls
+        # if brain_study:
+        #     segmentations += np.ones_like(segmentations) # Regionsprops neglect the label 0. But we need it, so offset all labels by 1
     else:
-
+        lbls = set(np.unique(segmentations)) if 0 in expected_lbls else set(np.unique(segmentations)) - {0}
+    missing_lbls = set(expected_lbls) - lbls
+
+    if 0 in expected_lbls:
+        segmentations += np.ones_like(segmentations) # Regionsprops neglects the label 0. But we need it, so offset all labels by 1
+
+    segmentations = np.squeeze(segmentations) # remove channel dimension, not needed anyway

     seg_props = regionprops(segmentations)
     centroids = np.asarray([c.centroid for c in seg_props]).astype(np.float32)
@@ -147,11 +155,15 @@ def segmentation_ohe_to_cardinal(segmentation):
     return np.argmax(cpy, axis=-1)[..., np.newaxis]


-def segmentation_cardinal_to_ohe(segmentation):
+def segmentation_cardinal_to_ohe(segmentation, labels_list: list = None):
     # Keep in mind that we don't handle the overlap between the segmentations!
-
-
-
+    #labels_list = np.unique(segmentation)[1:] if labels_list is None else labels_list
+    num_labels = len(labels_list)
+    expected_shape = segmentation.shape[:-1] + (num_labels,)
+    cpy = np.zeros(expected_shape, dtype=np.uint8)
+    seg_squeezed = np.squeeze(segmentation, axis=-1)
+    for ch, lbl in enumerate(labels_list):
+        cpy[seg_squeezed == lbl, ch] = 1
     return cpy
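Before the next hunk, which adds a GaussianFilter class, a hedged usage sketch of the reworked segmentation helpers; the toy label map is mine, and the full return value of get_segmentations_centroids is not shown in this diff:

import numpy as np
from DeepDeformationMapRegistration.utils.misc import (get_segmentations_centroids,
                                                        segmentation_cardinal_to_ohe)

# Toy cardinal label map of shape (H, W, D, 1) with labels 1 and 3.
seg = np.zeros((8, 8, 8, 1), dtype=np.uint8)
seg[:4, :4, :4, 0] = 1
seg[4:, 4:, 4:, 0] = 3

ohe = segmentation_cardinal_to_ohe(seg, labels_list=[1, 3])   # (8, 8, 8, 2), one binary channel per listed label
res = get_segmentations_centroids(seg, ohe=False, expected_lbls=[1, 3])  # centroid extraction via regionprops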
@@ -184,3 +196,50 @@ def scale_transformation(original_shape: [list, tuple, np.ndarray], dest_shape:

     return trf

+
+class GaussianFilter:
+    def __init__(self, size, sigma, dim, num_channels, stride=None, batch: bool=True):
+        """
+        Gaussian filter
+        :param size: Kernel size
+        :param sigma: Sigma of the Gaussian filter.
+        :param dim: Data dimensionality. Must be {2, 3}.
+        :param num_channels: Number of channels of the image to filter.
+        """
+        self.size = size
+        self.dim = dim
+        self.sigma = float(sigma)
+        self.num_channels = num_channels
+        self.stride = size // 2 if stride is None else int(stride)
+        if batch:
+            self.stride = [1] + [self.stride] * self.dim + [1] # No support for strides in the batch and channel dims
+        else:
+            self.stride = [self.stride] * self.dim + [1] # No support for strides in the batch and channel dims
+
+        self.convDN = getattr(tf.nn, 'conv%dd' % dim)
+        self.__GF = None
+
+        self.__build_gaussian_filter()
+
+    def __build_gaussian_filter(self):
+        range_1d = tf.range(-(self.size/2) + 1, self.size//2 + 1)
+        g_1d = tf.math.exp(-1.0 * tf.pow(range_1d, 2) / (2. * tf.pow(self.sigma, 2)))
+        g_1d_expanded = tf.expand_dims(g_1d, -1)
+        iterator = tf.constant(1)
+        self.__GF = tf.while_loop(lambda iterator, g_1d: tf.less(iterator, self.dim),
+                                  lambda iterator, g_1d: (iterator + 1, tf.expand_dims(g_1d, -1) * tf.transpose(g_1d_expanded)),
+                                  [iterator, g_1d],
+                                  [iterator.get_shape(), tf.TensorShape(None)], # Shape invariants
+                                  back_prop=False
+                                  )[-1]
+
+        self.__GF = tf.divide(self.__GF, tf.reduce_sum(self.__GF)) # Normalization
+        self.__GF = tf.reshape(self.__GF, (*[self.size]*self.dim, 1, 1)) # Add Ch_in and Ch_out for convolution
+        self.__GF = tf.tile(self.__GF, (*[1] * self.dim, self.num_channels, self.num_channels,))
+
+    def apply_filter(self, in_image):
+        return self.convDN(in_image, self.__GF, self.stride, 'SAME')
+
+    @property
+    def kernel(self):
+        return self.__GF
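A hedged usage sketch for the new GaussianFilter, written against the TF 1.x graph API that the requirements below pin; the shapes and values are illustrative:

import numpy as np
import tensorflow as tf
from DeepDeformationMapRegistration.utils.misc import GaussianFilter

# Smooth a single-channel 3D volume with a 5-voxel Gaussian kernel (sigma = 1).
vol = tf.placeholder(tf.float32, shape=(1, 64, 64, 64, 1))
gauss = GaussianFilter(size=5, sigma=1., dim=3, num_channels=1, stride=1, batch=True)
smoothed = gauss.apply_filter(vol)   # tf.nn.conv3d with the normalized Gaussian kernel, 'SAME' padding

with tf.Session() as sess:
    out = sess.run(smoothed, feed_dict={vol: np.random.rand(1, 64, 64, 64, 1).astype(np.float32)})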
DeepDeformationMapRegistration/utils/operators.py
CHANGED
@@ -63,3 +63,18 @@ def sample_unique(population, samples, tout=tf.int32):
     _, indices = tf.nn.top_k(z, samples)
     ret_val = tf.gather(population, indices)
     return tf.cast(ret_val, tout)
+
+
+def safe_medpy_metric(prediction, reference, nb_labels, metric_fnc, fnc_args: dict={}):
+    vals = list()
+    if 'voxelspacing' in fnc_args.keys():
+        diag = np.power(reference.shape[:-1] * fnc_args['voxelspacing'], 2)
+    else:
+        diag = np.power(reference.shape[:-1], 2)
+    diag = np.sqrt(np.sum(diag))
+    for l in range(nb_labels):
+        try:
+            vals.append(metric_fnc(prediction[..., l], reference[..., l], **fnc_args))
+        except RuntimeError:
+            vals.append(diag)
+    return vals
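A hedged sketch of calling safe_medpy_metric with medpy's Hausdorff distance, assuming medpy is installed; medpy raises RuntimeError on empty masks, which is the case the fallback to the (spacing-scaled) image diagonal covers. The toy volumes are mine:

import numpy as np
from medpy.metric.binary import hd   # assumed available; raises RuntimeError on empty masks
from DeepDeformationMapRegistration.utils.operators import safe_medpy_metric

# Toy one-hot segmentations of shape (H, W, D, nb_labels); channel 1 is empty on purpose.
prediction = np.zeros((16, 16, 16, 3), dtype=np.uint8)
reference = np.zeros((16, 16, 16, 3), dtype=np.uint8)
prediction[4:8, 4:8, 4:8, 0] = 1
reference[5:9, 5:9, 5:9, 0] = 1
prediction[10:12, 10:12, 10:12, 2] = 1
reference[10:13, 10:13, 10:13, 2] = 1

vals = safe_medpy_metric(prediction, reference, nb_labels=3, metric_fnc=hd,
                         fnc_args={'voxelspacing': np.asarray([1.5, 1.5, 1.5])})
# vals[1] is the image diagonal (~41.6 mm here) instead of a crash for the empty label.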
DeepDeformationMapRegistration/utils/visualization.py
CHANGED
@@ -1,5 +1,5 @@
 import matplotlib
-
+matplotlib.use('WebAgg')
 import matplotlib.pyplot as plt
 from mpl_toolkits.mplot3d import Axes3D
 import matplotlib.colors as mcolors
@@ -17,7 +17,7 @@ THRES = 0.9

 # COLOR MAPS
 chunks = np.linspace(0, 1, 10)
-cmap1 = plt.get_cmap('hsv',
+cmap1 = plt.get_cmap('hsv', 30)
 # cmaplist = [cmap1(i) for i in range(cmap1.N)]
 cmaplist = [(1, 1, 1, 1), (0, 0, 1, 1), (230 / 255, 97 / 255, 1 / 255, 1), (128 / 255, 0 / 255, 32 / 255, 1)]
 cmaplist[0] = (1, 1, 1, 1.0)
@@ -34,6 +34,14 @@ cmap4 = mcolors.LinearSegmentedColormap.from_list('mycmap', colors, N=100)

 cmap_bin = cm.get_cmap('viridis', 3) # viridis is the default colormap

+cmap_segs = np.asarray([mcolors.to_rgba(mcolors.CSS4_COLORS[c], 1) for c in mcolors.CSS4_COLORS.keys()])
+cmap_segs.sort()
+# rnd_idxs = [30, 17, 72, 90, 74, 39, 120, 63, 52, 79, 140, 68, 131, 109, 57, 49, 11, 132, 29, 46, 51, 26, 53, 7, 89, 47, 43, 121, 31, 28, 106, 92, 130, 117, 91, 118, 61, 5, 80, 93, 58, 133, 14, 98, 116, 76, 113, 111, 136, 142, 95, 122, 86, 77, 36, 97, 141, 115, 18, 81, 88, 87, 44, 146, 103, 67, 147, 48, 42, 83, 128, 65, 139, 69, 27, 135, 94, 134, 50, 19, 114, 0, 96, 10, 138, 75, 13, 12, 102, 32, 66, 16, 8, 73, 85, 145, 54, 37, 70, 143]
+# cmap_segs = cmap_segs[rnd_idxs]
+np.random.shuffle(cmap_segs)
+cmap_segs[0, -1] = 0
+cmap_segs = mcolors.ListedColormap(cmap_segs)
+

 def view_centerline_sample(sample: np.ndarray, dimensionality: int, ax=None, c=None, name=None):
     if dimensionality == 2:
@@ -321,7 +329,7 @@ def save_centreline_img(img, title, filename, fig=None):
     plt.close()


-def save_disp_map_img(disp_map, title, filename, affine_transf=False, fig=None, show=False):
+def save_disp_map_img(disp_map, title, filename, affine_transf=False, fig=None, show=False, step=1):
     if fig is not None:
         fig.clear()
         plt.figure(fig.number)
@@ -333,7 +341,7 @@ def save_disp_map_img(disp_map, title, filename, affine_transf=False, fig=None,
     if dim == 2:
         ax_x = fig.add_subplot(131)
         ax_x.set_title('H displacement')
-        im_x = ax_x.imshow(disp_map[..., C.H_DISP])
+        im_x = ax_x.imshow(disp_map[..., ::step, ::step, C.H_DISP])
         ax_x.tick_params(axis='both',
                          which='both',
                          bottom=False,
@@ -344,7 +352,7 @@ def save_disp_map_img(disp_map, title, filename, affine_transf=False, fig=None,

         ax_y = fig.add_subplot(132)
         ax_y.set_title('W displacement')
-        im_y = ax_y.imshow(disp_map[..., C.W_DISP])
+        im_y = ax_y.imshow(disp_map[..., ::step, ::step, C.W_DISP])
         ax_y.tick_params(axis='both',
                          which='both',
                          bottom=False,
@@ -371,7 +379,7 @@ def save_disp_map_img(disp_map, title, filename, affine_transf=False, fig=None,
                 ax.text(i, j, transf_mat[i, j], ha="center", va="center", color="b")

     else:
-        c, d, s = _prepare_quiver_map(disp_map, dim=dim)
+        c, d, s = _prepare_quiver_map(disp_map, dim=dim, spc=step)
         im = ax.imshow(s, interpolation='none', aspect='equal')
         ax.quiver(c[C.H_DISP], c[C.W_DISP], d[C.H_DISP], d[C.W_DISP],
                   scale=C.QUIVER_PARAMS.arrow_scale)
@@ -386,7 +394,7 @@ def save_disp_map_img(disp_map, title, filename, affine_transf=False, fig=None,
         fig.suptitle(title)
     else:
         ax = fig.add_subplot(111, projection='3d')
-        c, d, s = _prepare_quiver_map(disp_map[0, ...], dim=dim)
+        c, d, s = _prepare_quiver_map(disp_map[0, ...], dim=dim, spc=step)
         ax.quiver(c[C.H_DISP], c[C.W_DISP], c[C.D_DISP], d[C.H_DISP], d[C.W_DISP], d[C.D_DISP])
         _square_3d_plot(np.arange(0, dim_h-1), np.arange(0, dim_w-1), np.arange(0, dim_d-1), ax)
         fig.suptitle('Displacement map')
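The new step argument thins out the displacement and quiver plots by keeping every step-th vector. A hedged, standalone sketch of the same subsampling idea on a toy 2D field (plain matplotlib, not the repo's _prepare_quiver_map); the remaining hunks below then rework plot_predictions around grouped image and segmentation batches:

import numpy as np
import matplotlib.pyplot as plt

disp_map = np.random.randn(64, 64, 2)   # toy (H, W, 2) displacement field
step = 4                                # keep every 4th vector
ys, xs = np.mgrid[0:64:step, 0:64:step]

plt.imshow(np.linalg.norm(disp_map, axis=-1), origin='lower')   # magnitude as background
plt.quiver(xs, ys, disp_map[::step, ::step, 0], disp_map[::step, ::step, 1], color='w')
plt.savefig('quiver_step.png')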
@@ -810,7 +818,12 @@ def plot_dataset_3d(img_sets):
     return fig


-def plot_predictions(
+def plot_predictions(img_batches, disp_map_batch, seg_batches=None, step=1, filename='predictions', fig=None, show=False):
+    fix_img_batch, mov_img_batch, pred_img_batch = img_batches
+    if seg_batches != None:
+        fix_seg_batch, mov_seg_batch, pred_seg_batch = seg_batches
+    else:
+        fix_seg_batch = mov_seg_batch = pred_seg_batch = None
     num_rows = fix_img_batch.shape[0]
     img_dim = len(fix_img_batch.shape) - 2
     img_size = fix_img_batch.shape[1:-1]
@@ -828,6 +841,10 @@ def plot_predictions(fix_img_batch, mov_img_batch, disp_map_batch, pred_img_batc
         fix_img_batch = fix_img_batch[:, selected_slice, ...]
         mov_img_batch = mov_img_batch[:, selected_slice, ...]
         pred_img_batch = pred_img_batch[:, selected_slice, ...]
+        if seg_batches != None:
+            fix_seg_batch = fix_seg_batch[:, selected_slice, ...]
+            mov_seg_batch = mov_seg_batch[:, selected_slice, ...]
+            pred_seg_batch = pred_seg_batch[:, selected_slice, ...]
         disp_map_batch = disp_map_batch[:, selected_slice, ..., 1:] # Only the sagittal and longitudinal axes
         img_size = fix_img_batch.shape[1:-1]
     elif img_dim != 2:
@@ -836,16 +853,24 @@ def plot_predictions(fix_img_batch, mov_img_batch, disp_map_batch, pred_img_batc
     for row in range(num_rows):
         fix_img = fix_img_batch[row, :, :, 0].transpose()
         mov_img = mov_img_batch[row, :, :, 0].transpose()
-        disp_map = disp_map_batch[row, :, :, :].transpose((1, 0, 2))
         pred_img = pred_img_batch[row, :, :, 0].transpose()
-
+        if seg_batches != None:
+            fix_seg = fix_seg_batch[row, :, :, 0].transpose()
+            mov_seg= mov_seg_batch[row, :, :, 0].transpose()
+            pred_seg = pred_seg_batch[row, :, :, 0].transpose()
+        disp_map = disp_map_batch[row, :, :, :].transpose((1, 0, 2))
+        ax[row, 0].imshow(fix_img, origin='lower', cmap='gray')
+        if seg_batches != None:
+            ax[row, 0].imshow(fix_seg, origin='lower', cmap=cmap_segs)
         ax[row, 0].tick_params(axis='both',
                                which='both',
                                bottom=False,
                                left=False,
                                labelleft=False,
                                labelbottom=False)
-        ax[row, 1].imshow(mov_img, origin='lower')
+        ax[row, 1].imshow(mov_img, origin='lower', cmap='gray')
+        if seg_batches != None:
+            ax[row, 1].imshow(mov_seg, origin='lower', cmap=cmap_segs)
         ax[row, 1].tick_params(axis='both',
                                which='both',
                                bottom=False,
@@ -853,7 +878,7 @@ def plot_predictions(fix_img_batch, mov_img_batch, disp_map_batch, pred_img_batc
                                labelleft=False,
                                labelbottom=False)

-        c, d, s = _prepare_quiver_map(disp_map, spc=
+        c, d, s = _prepare_quiver_map(disp_map, spc=step)
         cx, cy = c
         dx, dy = d
         disp_map_color = _prepare_colormap(disp_map)
@@ -866,7 +891,9 @@ def plot_predictions(fix_img_batch, mov_img_batch, disp_map_batch, pred_img_batc
                                labelleft=False,
                                labelbottom=False)

-        ax[row, 3].imshow(mov_img, origin='lower')
+        ax[row, 3].imshow(mov_img, origin='lower', cmap='gray')
+        if seg_batches != None:
+            ax[row, 3].imshow(mov_seg, origin='lower', cmap=cmap_segs)
         ax[row, 3].quiver(cx, cy, dx, dy, units='dots', scale=1, color='w')
         ax[row, 3].tick_params(axis='both',
                                which='both',
@@ -875,7 +902,9 @@ def plot_predictions(fix_img_batch, mov_img_batch, disp_map_batch, pred_img_batc
                                labelleft=False,
                                labelbottom=False)

-        ax[row, 4].imshow(pred_img, origin='lower')
+        ax[row, 4].imshow(pred_img, origin='lower', cmap='gray')
+        if seg_batches != None:
+            ax[row, 4].imshow(pred_seg, origin='lower', cmap=cmap_segs)
         ax[row, 4].tick_params(axis='both',
                                which='both',
                                bottom=False,
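The segmentation overlays work because cmap_segs sets the alpha of its first entry to zero (cmap_segs[0, -1] = 0), so background voxels stay transparent over the grayscale image. A hedged, self-contained version of the same trick with a toy label map (mine, not the repo's plot_predictions):

import numpy as np
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt

colors = np.asarray([mcolors.to_rgba(c, 1) for c in ('white', 'red', 'royalblue', 'orange')])
colors[0, -1] = 0                                # label 0 (background) becomes fully transparent
overlay_cmap = mcolors.ListedColormap(colors)

img = np.random.rand(64, 64)                     # stand-in for fix_img / mov_img / pred_img
seg = np.zeros((64, 64), dtype=np.uint8)
seg[20:40, 20:40] = 1
seg[45:60, 10:25] = 2

plt.imshow(img, origin='lower', cmap='gray')
plt.imshow(seg, origin='lower', cmap=overlay_cmap, vmin=0, vmax=len(colors) - 1)
plt.savefig('overlay.png')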
requirements.txt
CHANGED
@@ -27,20 +27,23 @@ et-xmlfile==1.0.1
 fastrlock==0.6
 flatbuffers==1.12
 future==0.18.2
-gast==0.
+gast==0.2.2
 google-auth==1.35.0
 google-auth-oauthlib==0.4.6
 google-pasta==0.2.0
 googleapis-common-protos==1.53.0
 grpcio==1.40.0
-h5py==
+h5py==2.10.0
 idna==2.10
 imageio==2.9.0
 importlib-metadata==3.4.0
 importlib-resources==5.2.2
+ipydatawidgets==4.2.0
 ipykernel==5.5.3
 ipython==7.16.1
 ipython-genutils==0.2.0
+ipyvolume==0.5.2
+ipywebrtc==0.6.0
 ipywidgets==7.6.3
 jedi==0.18.0
 Jinja2==2.11.3
@@ -84,7 +87,7 @@ patsy==0.5.1
 pexpect==4.8.0
 pickleshare==0.7.5
 Pillow==8.1.0
-
+pkg_resources==0.0.0
 plotly==4.14.3
 plyfile==0.7.3
 probreg==0.3.1
@@ -107,6 +110,7 @@ pyrsistent==0.17.3
 pystrum==0.1
 python-dateutil==2.8.1
 python-utils==2.5.6
+pythreejs==2.3.0
 pytz==2021.1
 PyWavelets==1.1.1
 PyYAML==5.4.1
@@ -123,13 +127,13 @@ SimpleITK==2.0.2
 six==1.15.0
 sklearn==0.0
 statsmodels==0.12.2
-
+tabulate==0.8.9
+tensorboard==1.14.0
 tensorboard-data-server==0.6.1
-tensorboard-plugin-wit==1.8.0
 tensorflow-addons==0.14.0
 tensorflow-datasets==4.4.0
-tensorflow-estimator==1.
-tensorflow-gpu==1.
+tensorflow-estimator==1.14.0
+tensorflow-gpu==1.14.0
 tensorflow-metadata==1.2.0
 termcolor==1.1.0
 terminado==0.9.4
@@ -141,6 +145,7 @@ tikzplotlib==0.9.7
 tornado==6.1
 tqdm==4.56.0
 traitlets==4.3.3
+traittypes==0.2.1
 transformations==2020.1.1
 trimesh==3.9.29
 typeguard==2.12.1
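The pins settle the stack on the TensorFlow 1.14 line (tensorflow-gpu, tensorflow-estimator and tensorboard all at 1.14.0, h5py at 2.10.0). A hedged post-install sanity check (my snippet, not part of the repo):

# Run after `pip install -r requirements.txt`.
import h5py
import tensorflow as tf

assert tf.__version__.startswith('1.14'), tf.__version__
assert h5py.__version__.startswith('2.10'), h5py.__version__
print('GPU available:', tf.test.is_gpu_available())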