jpdefrutos commited on
Commit
78ae283
·
1 Parent(s): 286a978

Updating latest changes DDMR

Browse files
DeepDeformationMapRegistration/layers/augmentation.py CHANGED
@@ -133,7 +133,7 @@ class AugmentationLayer(kl.Layer):
133
  mov_img = tf.zeros_like(fix_img)
134
  mov_segm = tf.zeros_like(fix_segm)
135
 
136
- disp_map = tf.tile(tf.zeros_like(fix_img), [1, 1, 1, 1, 3])
137
 
138
  if self.out_img_shape is not None:
139
  fix_img = self.downsize_image(fix_img)
 
133
  mov_img = tf.zeros_like(fix_img)
134
  mov_segm = tf.zeros_like(fix_segm)
135
 
136
+ disp_map = tf.tile(tf.zeros_like(fix_img), [1, 1, 1, 1, 3]) # TODO: change, don't use tile!!
137
 
138
  if self.out_img_shape is not None:
139
  fix_img = self.downsize_image(fix_img)
DeepDeformationMapRegistration/layers/upsampling.py CHANGED
@@ -485,6 +485,8 @@ def UpInterpolate3D(x,
485
  nb, nr, nc, nd, nh = tf.TensorShape(x).as_list()
486
  elif data_format == 'channels_first':
487
  nb, nh, nr, nc, nd = tf.TensorShape(x).as_list()
 
 
488
 
489
  r = size[0]
490
  c = size[1]
 
485
  nb, nr, nc, nd, nh = tf.TensorShape(x).as_list()
486
  elif data_format == 'channels_first':
487
  nb, nh, nr, nc, nd = tf.TensorShape(x).as_list()
488
+ else:
489
+ raise ValueError('Invalid option: ', data_format)
490
 
491
  r = size[0]
492
  c = size[1]
DeepDeformationMapRegistration/utils/constants.py CHANGED
@@ -413,6 +413,7 @@ WAR = 30 # Warning
413
  ERR = 40 # Error
414
  DEB = 10 # Debug
415
  CRI = 50 # Critical
 
416
 
417
  SEVERITY_STR = {INF: 'INFO',
418
  WAR: 'WARNING',
@@ -511,8 +512,8 @@ REG_MANUAL_W = [1.] * len(REG_PRIOR_W)
511
  IXI_DATASET_iso_to_cubic_scales = np.asarray([0.655491 + 0.039223, 0.496783 + 0.029349, 0.499691 + 0.028155])
512
  # ...OSLO_COMET_CT/Formatted_128x128x128/zoom_factors.csv contain the scale factors of all the training samples from isotropic to 128x128x128
513
  COMET_DATASET_iso_to_cubic_scales = np.asarray([0.455259 + 0.048027, 0.492012 + 0.044298, 0.577552 + 0.051708])
514
- MAX_AUG_DISP_ISOT = 30
515
- MAX_AUG_DEF_ISOT = 6
516
  MAX_AUG_DISP = np.max(MAX_AUG_DISP_ISOT * IXI_DATASET_iso_to_cubic_scales) # Scaled displacements
517
  MAX_AUG_DEF = np.max(MAX_AUG_DEF_ISOT * IXI_DATASET_iso_to_cubic_scales) # Scaled deformations
518
  MAX_AUG_ANGLE = np.max([np.arctan(np.tan(10*np.pi/180) * IXI_DATASET_iso_to_cubic_scales[1] / IXI_DATASET_iso_to_cubic_scales[0]) * 180 / np.pi,
 
413
  ERR = 40 # Error
414
  DEB = 10 # Debug
415
  CRI = 50 # Critical
416
+ SUMMARY_LINE_LENGTH = 150
417
 
418
  SEVERITY_STR = {INF: 'INFO',
419
  WAR: 'WARNING',
 
512
  IXI_DATASET_iso_to_cubic_scales = np.asarray([0.655491 + 0.039223, 0.496783 + 0.029349, 0.499691 + 0.028155])
513
  # ...OSLO_COMET_CT/Formatted_128x128x128/zoom_factors.csv contain the scale factors of all the training samples from isotropic to 128x128x128
514
  COMET_DATASET_iso_to_cubic_scales = np.asarray([0.455259 + 0.048027, 0.492012 + 0.044298, 0.577552 + 0.051708])
515
+ MAX_AUG_DISP_ISOT = 30 # mm
516
+ MAX_AUG_DEF_ISOT = 6 # mm
517
  MAX_AUG_DISP = np.max(MAX_AUG_DISP_ISOT * IXI_DATASET_iso_to_cubic_scales) # Scaled displacements
518
  MAX_AUG_DEF = np.max(MAX_AUG_DEF_ISOT * IXI_DATASET_iso_to_cubic_scales) # Scaled deformations
519
  MAX_AUG_ANGLE = np.max([np.arctan(np.tan(10*np.pi/180) * IXI_DATASET_iso_to_cubic_scales[1] / IXI_DATASET_iso_to_cubic_scales[0]) * 180 / np.pi,
DeepDeformationMapRegistration/utils/misc.py CHANGED
@@ -8,6 +8,7 @@ from DeepDeformationMapRegistration.layers.b_splines import interpolate_spline
8
  from DeepDeformationMapRegistration.utils.thin_plate_splines import ThinPlateSplines
9
  from tensorflow import squeeze
10
  from scipy.ndimage import zoom
 
11
 
12
 
13
  def try_mkdir(dir, verbose=True):
@@ -119,15 +120,22 @@ class DisplacementMapInterpolator:
119
  return disp
120
 
121
 
122
- def get_segmentations_centroids(segmentations, ohe=True, expected_lbls=range(0, 28), missing_centroid=[np.nan]*3, brain_study=True):
123
  segmentations = np.squeeze(segmentations)
124
  if ohe:
125
- segmentations = np.sum(segmentations, axis=-1).astype(np.uint8)
126
- missing_lbls = set(expected_lbls) - set(np.unique(segmentations))
127
- if brain_study:
128
- segmentations += np.ones_like(segmentations) # Regionsprops neglect the label 0. But we need it, so offset all labels by 1
 
129
  else:
130
- missing_lbls = set(expected_lbls) - set(np.unique(segmentations))
 
 
 
 
 
 
131
 
132
  seg_props = regionprops(segmentations)
133
  centroids = np.asarray([c.centroid for c in seg_props]).astype(np.float32)
@@ -147,11 +155,15 @@ def segmentation_ohe_to_cardinal(segmentation):
147
  return np.argmax(cpy, axis=-1)[..., np.newaxis]
148
 
149
 
150
- def segmentation_cardinal_to_ohe(segmentation):
151
  # Keep in mind that we don't handle the overlap between the segmentations!
152
- cpy = np.tile(np.zeros_like(segmentation), (1, 1, 1, len(np.unique(segmentation)[1:])))
153
- for ch, lbl in enumerate(np.unique(segmentation)[1:]):
154
- cpy[segmentation == lbl, ch] = 1
 
 
 
 
155
  return cpy
156
 
157
 
@@ -184,3 +196,50 @@ def scale_transformation(original_shape: [list, tuple, np.ndarray], dest_shape:
184
 
185
  return trf
186
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  from DeepDeformationMapRegistration.utils.thin_plate_splines import ThinPlateSplines
9
  from tensorflow import squeeze
10
  from scipy.ndimage import zoom
11
+ import tensorflow as tf
12
 
13
 
14
  def try_mkdir(dir, verbose=True):
 
120
  return disp
121
 
122
 
123
+ def get_segmentations_centroids(segmentations, ohe=True, expected_lbls=range(1, 28), missing_centroid=[np.nan]*3, brain_study=True):
124
  segmentations = np.squeeze(segmentations)
125
  if ohe:
126
+ segmentations = segmentation_ohe_to_cardinal(segmentations)
127
+ lbls = set(np.unique(segmentations)) - {0} # Remove the 0 value returned by np.unique, no label
128
+ # missing_lbls = set(expected_lbls) - lbls
129
+ # if brain_study:
130
+ # segmentations += np.ones_like(segmentations) # Regionsprops neglect the label 0. But we need it, so offset all labels by 1
131
  else:
132
+ lbls = set(np.unique(segmentations)) if 0 in expected_lbls else set(np.unique(segmentations)) - {0}
133
+ missing_lbls = set(expected_lbls) - lbls
134
+
135
+ if 0 in expected_lbls:
136
+ segmentations += np.ones_like(segmentations) # Regionsprops neglects the label 0. But we need it, so offset all labels by 1
137
+
138
+ segmentations = np.squeeze(segmentations) # remove channel dimension, not needed anyway
139
 
140
  seg_props = regionprops(segmentations)
141
  centroids = np.asarray([c.centroid for c in seg_props]).astype(np.float32)
 
155
  return np.argmax(cpy, axis=-1)[..., np.newaxis]
156
 
157
 
158
+ def segmentation_cardinal_to_ohe(segmentation, labels_list: list = None):
159
  # Keep in mind that we don't handle the overlap between the segmentations!
160
+ #labels_list = np.unique(segmentation)[1:] if labels_list is None else labels_list
161
+ num_labels = len(labels_list)
162
+ expected_shape = segmentation.shape[:-1] + (num_labels,)
163
+ cpy = np.zeros(expected_shape, dtype=np.uint8)
164
+ seg_squeezed = np.squeeze(segmentation, axis=-1)
165
+ for ch, lbl in enumerate(labels_list):
166
+ cpy[seg_squeezed == lbl, ch] = 1
167
  return cpy
168
 
169
 
 
196
 
197
  return trf
198
 
199
+
200
+ class GaussianFilter:
201
+ def __init__(self, size, sigma, dim, num_channels, stride=None, batch: bool=True):
202
+ """
203
+ Gaussian filter
204
+ :param size: Kernel size
205
+ :param sigma: Sigma of the Gaussian filter.
206
+ :param dim: Data dimensionality. Must be {2, 3}.
207
+ :param num_channels: Number of channels of the image to filter.
208
+ """
209
+ self.size = size
210
+ self.dim = dim
211
+ self.sigma = float(sigma)
212
+ self.num_channels = num_channels
213
+ self.stride = size // 2 if stride is None else int(stride)
214
+ if batch:
215
+ self.stride = [1] + [self.stride] * self.dim + [1] # No support for strides in the batch and channel dims
216
+ else:
217
+ self.stride = [self.stride] * self.dim + [1] # No support for strides in the batch and channel dims
218
+
219
+ self.convDN = getattr(tf.nn, 'conv%dd' % dim)
220
+ self.__GF = None
221
+
222
+ self.__build_gaussian_filter()
223
+
224
+ def __build_gaussian_filter(self):
225
+ range_1d = tf.range(-(self.size/2) + 1, self.size//2 + 1)
226
+ g_1d = tf.math.exp(-1.0 * tf.pow(range_1d, 2) / (2. * tf.pow(self.sigma, 2)))
227
+ g_1d_expanded = tf.expand_dims(g_1d, -1)
228
+ iterator = tf.constant(1)
229
+ self.__GF = tf.while_loop(lambda iterator, g_1d: tf.less(iterator, self.dim),
230
+ lambda iterator, g_1d: (iterator + 1, tf.expand_dims(g_1d, -1) * tf.transpose(g_1d_expanded)),
231
+ [iterator, g_1d],
232
+ [iterator.get_shape(), tf.TensorShape(None)], # Shape invariants
233
+ back_prop=False
234
+ )[-1]
235
+
236
+ self.__GF = tf.divide(self.__GF, tf.reduce_sum(self.__GF)) # Normalization
237
+ self.__GF = tf.reshape(self.__GF, (*[self.size]*self.dim, 1, 1)) # Add Ch_in and Ch_out for convolution
238
+ self.__GF = tf.tile(self.__GF, (*[1] * self.dim, self.num_channels, self.num_channels,))
239
+
240
+ def apply_filter(self, in_image):
241
+ return self.convDN(in_image, self.__GF, self.stride, 'SAME')
242
+
243
+ @property
244
+ def kernel(self):
245
+ return self.__GF
DeepDeformationMapRegistration/utils/operators.py CHANGED
@@ -63,3 +63,18 @@ def sample_unique(population, samples, tout=tf.int32):
63
  _, indices = tf.nn.top_k(z, samples)
64
  ret_val = tf.gather(population, indices)
65
  return tf.cast(ret_val, tout)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  _, indices = tf.nn.top_k(z, samples)
64
  ret_val = tf.gather(population, indices)
65
  return tf.cast(ret_val, tout)
66
+
67
+
68
+ def safe_medpy_metric(prediction, reference, nb_labels, metric_fnc, fnc_args: dict={}):
69
+ vals = list()
70
+ if 'voxelspacing' in fnc_args.keys():
71
+ diag = np.power(reference.shape[:-1] * fnc_args['voxelspacing'], 2)
72
+ else:
73
+ diag = np.power(reference.shape[:-1], 2)
74
+ diag = np.sqrt(np.sum(diag))
75
+ for l in range(nb_labels):
76
+ try:
77
+ vals.append(metric_fnc(prediction[..., l], reference[..., l], **fnc_args))
78
+ except RuntimeError:
79
+ vals.append(diag)
80
+ return vals
DeepDeformationMapRegistration/utils/visualization.py CHANGED
@@ -1,5 +1,5 @@
1
  import matplotlib
2
- #matplotlib.use('TkAgg')
3
  import matplotlib.pyplot as plt
4
  from mpl_toolkits.mplot3d import Axes3D
5
  import matplotlib.colors as mcolors
@@ -17,7 +17,7 @@ THRES = 0.9
17
 
18
  # COLOR MAPS
19
  chunks = np.linspace(0, 1, 10)
20
- cmap1 = plt.get_cmap('hsv', 4)
21
  # cmaplist = [cmap1(i) for i in range(cmap1.N)]
22
  cmaplist = [(1, 1, 1, 1), (0, 0, 1, 1), (230 / 255, 97 / 255, 1 / 255, 1), (128 / 255, 0 / 255, 32 / 255, 1)]
23
  cmaplist[0] = (1, 1, 1, 1.0)
@@ -34,6 +34,14 @@ cmap4 = mcolors.LinearSegmentedColormap.from_list('mycmap', colors, N=100)
34
 
35
  cmap_bin = cm.get_cmap('viridis', 3) # viridis is the default colormap
36
 
 
 
 
 
 
 
 
 
37
 
38
  def view_centerline_sample(sample: np.ndarray, dimensionality: int, ax=None, c=None, name=None):
39
  if dimensionality == 2:
@@ -321,7 +329,7 @@ def save_centreline_img(img, title, filename, fig=None):
321
  plt.close()
322
 
323
 
324
- def save_disp_map_img(disp_map, title, filename, affine_transf=False, fig=None, show=False):
325
  if fig is not None:
326
  fig.clear()
327
  plt.figure(fig.number)
@@ -333,7 +341,7 @@ def save_disp_map_img(disp_map, title, filename, affine_transf=False, fig=None,
333
  if dim == 2:
334
  ax_x = fig.add_subplot(131)
335
  ax_x.set_title('H displacement')
336
- im_x = ax_x.imshow(disp_map[..., C.H_DISP])
337
  ax_x.tick_params(axis='both',
338
  which='both',
339
  bottom=False,
@@ -344,7 +352,7 @@ def save_disp_map_img(disp_map, title, filename, affine_transf=False, fig=None,
344
 
345
  ax_y = fig.add_subplot(132)
346
  ax_y.set_title('W displacement')
347
- im_y = ax_y.imshow(disp_map[..., C.W_DISP])
348
  ax_y.tick_params(axis='both',
349
  which='both',
350
  bottom=False,
@@ -371,7 +379,7 @@ def save_disp_map_img(disp_map, title, filename, affine_transf=False, fig=None,
371
  ax.text(i, j, transf_mat[i, j], ha="center", va="center", color="b")
372
 
373
  else:
374
- c, d, s = _prepare_quiver_map(disp_map, dim=dim)
375
  im = ax.imshow(s, interpolation='none', aspect='equal')
376
  ax.quiver(c[C.H_DISP], c[C.W_DISP], d[C.H_DISP], d[C.W_DISP],
377
  scale=C.QUIVER_PARAMS.arrow_scale)
@@ -386,7 +394,7 @@ def save_disp_map_img(disp_map, title, filename, affine_transf=False, fig=None,
386
  fig.suptitle(title)
387
  else:
388
  ax = fig.add_subplot(111, projection='3d')
389
- c, d, s = _prepare_quiver_map(disp_map[0, ...], dim=dim)
390
  ax.quiver(c[C.H_DISP], c[C.W_DISP], c[C.D_DISP], d[C.H_DISP], d[C.W_DISP], d[C.D_DISP])
391
  _square_3d_plot(np.arange(0, dim_h-1), np.arange(0, dim_w-1), np.arange(0, dim_d-1), ax)
392
  fig.suptitle('Displacement map')
@@ -810,7 +818,12 @@ def plot_dataset_3d(img_sets):
810
  return fig
811
 
812
 
813
- def plot_predictions(fix_img_batch, mov_img_batch, disp_map_batch, pred_img_batch, filename='predictions', fig=None, show=False):
 
 
 
 
 
814
  num_rows = fix_img_batch.shape[0]
815
  img_dim = len(fix_img_batch.shape) - 2
816
  img_size = fix_img_batch.shape[1:-1]
@@ -828,6 +841,10 @@ def plot_predictions(fix_img_batch, mov_img_batch, disp_map_batch, pred_img_batc
828
  fix_img_batch = fix_img_batch[:, selected_slice, ...]
829
  mov_img_batch = mov_img_batch[:, selected_slice, ...]
830
  pred_img_batch = pred_img_batch[:, selected_slice, ...]
 
 
 
 
831
  disp_map_batch = disp_map_batch[:, selected_slice, ..., 1:] # Only the sagittal and longitudinal axes
832
  img_size = fix_img_batch.shape[1:-1]
833
  elif img_dim != 2:
@@ -836,16 +853,24 @@ def plot_predictions(fix_img_batch, mov_img_batch, disp_map_batch, pred_img_batc
836
  for row in range(num_rows):
837
  fix_img = fix_img_batch[row, :, :, 0].transpose()
838
  mov_img = mov_img_batch[row, :, :, 0].transpose()
839
- disp_map = disp_map_batch[row, :, :, :].transpose((1, 0, 2))
840
  pred_img = pred_img_batch[row, :, :, 0].transpose()
841
- ax[row, 0].imshow(fix_img, origin='lower')
 
 
 
 
 
 
 
842
  ax[row, 0].tick_params(axis='both',
843
  which='both',
844
  bottom=False,
845
  left=False,
846
  labelleft=False,
847
  labelbottom=False)
848
- ax[row, 1].imshow(mov_img, origin='lower')
 
 
849
  ax[row, 1].tick_params(axis='both',
850
  which='both',
851
  bottom=False,
@@ -853,7 +878,7 @@ def plot_predictions(fix_img_batch, mov_img_batch, disp_map_batch, pred_img_batc
853
  labelleft=False,
854
  labelbottom=False)
855
 
856
- c, d, s = _prepare_quiver_map(disp_map, spc=5)
857
  cx, cy = c
858
  dx, dy = d
859
  disp_map_color = _prepare_colormap(disp_map)
@@ -866,7 +891,9 @@ def plot_predictions(fix_img_batch, mov_img_batch, disp_map_batch, pred_img_batc
866
  labelleft=False,
867
  labelbottom=False)
868
 
869
- ax[row, 3].imshow(mov_img, origin='lower')
 
 
870
  ax[row, 3].quiver(cx, cy, dx, dy, units='dots', scale=1, color='w')
871
  ax[row, 3].tick_params(axis='both',
872
  which='both',
@@ -875,7 +902,9 @@ def plot_predictions(fix_img_batch, mov_img_batch, disp_map_batch, pred_img_batc
875
  labelleft=False,
876
  labelbottom=False)
877
 
878
- ax[row, 4].imshow(pred_img, origin='lower')
 
 
879
  ax[row, 4].tick_params(axis='both',
880
  which='both',
881
  bottom=False,
 
1
  import matplotlib
2
+ matplotlib.use('WebAgg')
3
  import matplotlib.pyplot as plt
4
  from mpl_toolkits.mplot3d import Axes3D
5
  import matplotlib.colors as mcolors
 
17
 
18
  # COLOR MAPS
19
  chunks = np.linspace(0, 1, 10)
20
+ cmap1 = plt.get_cmap('hsv', 30)
21
  # cmaplist = [cmap1(i) for i in range(cmap1.N)]
22
  cmaplist = [(1, 1, 1, 1), (0, 0, 1, 1), (230 / 255, 97 / 255, 1 / 255, 1), (128 / 255, 0 / 255, 32 / 255, 1)]
23
  cmaplist[0] = (1, 1, 1, 1.0)
 
34
 
35
  cmap_bin = cm.get_cmap('viridis', 3) # viridis is the default colormap
36
 
37
+ cmap_segs = np.asarray([mcolors.to_rgba(mcolors.CSS4_COLORS[c], 1) for c in mcolors.CSS4_COLORS.keys()])
38
+ cmap_segs.sort()
39
+ # rnd_idxs = [30, 17, 72, 90, 74, 39, 120, 63, 52, 79, 140, 68, 131, 109, 57, 49, 11, 132, 29, 46, 51, 26, 53, 7, 89, 47, 43, 121, 31, 28, 106, 92, 130, 117, 91, 118, 61, 5, 80, 93, 58, 133, 14, 98, 116, 76, 113, 111, 136, 142, 95, 122, 86, 77, 36, 97, 141, 115, 18, 81, 88, 87, 44, 146, 103, 67, 147, 48, 42, 83, 128, 65, 139, 69, 27, 135, 94, 134, 50, 19, 114, 0, 96, 10, 138, 75, 13, 12, 102, 32, 66, 16, 8, 73, 85, 145, 54, 37, 70, 143]
40
+ # cmap_segs = cmap_segs[rnd_idxs]
41
+ np.random.shuffle(cmap_segs)
42
+ cmap_segs[0, -1] = 0
43
+ cmap_segs = mcolors.ListedColormap(cmap_segs)
44
+
45
 
46
  def view_centerline_sample(sample: np.ndarray, dimensionality: int, ax=None, c=None, name=None):
47
  if dimensionality == 2:
 
329
  plt.close()
330
 
331
 
332
+ def save_disp_map_img(disp_map, title, filename, affine_transf=False, fig=None, show=False, step=1):
333
  if fig is not None:
334
  fig.clear()
335
  plt.figure(fig.number)
 
341
  if dim == 2:
342
  ax_x = fig.add_subplot(131)
343
  ax_x.set_title('H displacement')
344
+ im_x = ax_x.imshow(disp_map[..., ::step, ::step, C.H_DISP])
345
  ax_x.tick_params(axis='both',
346
  which='both',
347
  bottom=False,
 
352
 
353
  ax_y = fig.add_subplot(132)
354
  ax_y.set_title('W displacement')
355
+ im_y = ax_y.imshow(disp_map[..., ::step, ::step, C.W_DISP])
356
  ax_y.tick_params(axis='both',
357
  which='both',
358
  bottom=False,
 
379
  ax.text(i, j, transf_mat[i, j], ha="center", va="center", color="b")
380
 
381
  else:
382
+ c, d, s = _prepare_quiver_map(disp_map, dim=dim, spc=step)
383
  im = ax.imshow(s, interpolation='none', aspect='equal')
384
  ax.quiver(c[C.H_DISP], c[C.W_DISP], d[C.H_DISP], d[C.W_DISP],
385
  scale=C.QUIVER_PARAMS.arrow_scale)
 
394
  fig.suptitle(title)
395
  else:
396
  ax = fig.add_subplot(111, projection='3d')
397
+ c, d, s = _prepare_quiver_map(disp_map[0, ...], dim=dim, spc=step)
398
  ax.quiver(c[C.H_DISP], c[C.W_DISP], c[C.D_DISP], d[C.H_DISP], d[C.W_DISP], d[C.D_DISP])
399
  _square_3d_plot(np.arange(0, dim_h-1), np.arange(0, dim_w-1), np.arange(0, dim_d-1), ax)
400
  fig.suptitle('Displacement map')
 
818
  return fig
819
 
820
 
821
+ def plot_predictions(img_batches, disp_map_batch, seg_batches=None, step=1, filename='predictions', fig=None, show=False):
822
+ fix_img_batch, mov_img_batch, pred_img_batch = img_batches
823
+ if seg_batches != None:
824
+ fix_seg_batch, mov_seg_batch, pred_seg_batch = seg_batches
825
+ else:
826
+ fix_seg_batch = mov_seg_batch = pred_seg_batch = None
827
  num_rows = fix_img_batch.shape[0]
828
  img_dim = len(fix_img_batch.shape) - 2
829
  img_size = fix_img_batch.shape[1:-1]
 
841
  fix_img_batch = fix_img_batch[:, selected_slice, ...]
842
  mov_img_batch = mov_img_batch[:, selected_slice, ...]
843
  pred_img_batch = pred_img_batch[:, selected_slice, ...]
844
+ if seg_batches != None:
845
+ fix_seg_batch = fix_seg_batch[:, selected_slice, ...]
846
+ mov_seg_batch = mov_seg_batch[:, selected_slice, ...]
847
+ pred_seg_batch = pred_seg_batch[:, selected_slice, ...]
848
  disp_map_batch = disp_map_batch[:, selected_slice, ..., 1:] # Only the sagittal and longitudinal axes
849
  img_size = fix_img_batch.shape[1:-1]
850
  elif img_dim != 2:
 
853
  for row in range(num_rows):
854
  fix_img = fix_img_batch[row, :, :, 0].transpose()
855
  mov_img = mov_img_batch[row, :, :, 0].transpose()
 
856
  pred_img = pred_img_batch[row, :, :, 0].transpose()
857
+ if seg_batches != None:
858
+ fix_seg = fix_seg_batch[row, :, :, 0].transpose()
859
+ mov_seg= mov_seg_batch[row, :, :, 0].transpose()
860
+ pred_seg = pred_seg_batch[row, :, :, 0].transpose()
861
+ disp_map = disp_map_batch[row, :, :, :].transpose((1, 0, 2))
862
+ ax[row, 0].imshow(fix_img, origin='lower', cmap='gray')
863
+ if seg_batches != None:
864
+ ax[row, 0].imshow(fix_seg, origin='lower', cmap=cmap_segs)
865
  ax[row, 0].tick_params(axis='both',
866
  which='both',
867
  bottom=False,
868
  left=False,
869
  labelleft=False,
870
  labelbottom=False)
871
+ ax[row, 1].imshow(mov_img, origin='lower', cmap='gray')
872
+ if seg_batches != None:
873
+ ax[row, 1].imshow(mov_seg, origin='lower', cmap=cmap_segs)
874
  ax[row, 1].tick_params(axis='both',
875
  which='both',
876
  bottom=False,
 
878
  labelleft=False,
879
  labelbottom=False)
880
 
881
+ c, d, s = _prepare_quiver_map(disp_map, spc=step)
882
  cx, cy = c
883
  dx, dy = d
884
  disp_map_color = _prepare_colormap(disp_map)
 
891
  labelleft=False,
892
  labelbottom=False)
893
 
894
+ ax[row, 3].imshow(mov_img, origin='lower', cmap='gray')
895
+ if seg_batches != None:
896
+ ax[row, 3].imshow(mov_seg, origin='lower', cmap=cmap_segs)
897
  ax[row, 3].quiver(cx, cy, dx, dy, units='dots', scale=1, color='w')
898
  ax[row, 3].tick_params(axis='both',
899
  which='both',
 
902
  labelleft=False,
903
  labelbottom=False)
904
 
905
+ ax[row, 4].imshow(pred_img, origin='lower', cmap='gray')
906
+ if seg_batches != None:
907
+ ax[row, 4].imshow(pred_seg, origin='lower', cmap=cmap_segs)
908
  ax[row, 4].tick_params(axis='both',
909
  which='both',
910
  bottom=False,
requirements.txt CHANGED
@@ -27,20 +27,23 @@ et-xmlfile==1.0.1
27
  fastrlock==0.6
28
  flatbuffers==1.12
29
  future==0.18.2
30
- gast==0.4.0
31
  google-auth==1.35.0
32
  google-auth-oauthlib==0.4.6
33
  google-pasta==0.2.0
34
  googleapis-common-protos==1.53.0
35
  grpcio==1.40.0
36
- h5py==3.1.0
37
  idna==2.10
38
  imageio==2.9.0
39
  importlib-metadata==3.4.0
40
  importlib-resources==5.2.2
 
41
  ipykernel==5.5.3
42
  ipython==7.16.1
43
  ipython-genutils==0.2.0
 
 
44
  ipywidgets==7.6.3
45
  jedi==0.18.0
46
  Jinja2==2.11.3
@@ -84,7 +87,7 @@ patsy==0.5.1
84
  pexpect==4.8.0
85
  pickleshare==0.7.5
86
  Pillow==8.1.0
87
- pkg-resources==0.0.0
88
  plotly==4.14.3
89
  plyfile==0.7.3
90
  probreg==0.3.1
@@ -107,6 +110,7 @@ pyrsistent==0.17.3
107
  pystrum==0.1
108
  python-dateutil==2.8.1
109
  python-utils==2.5.6
 
110
  pytz==2021.1
111
  PyWavelets==1.1.1
112
  PyYAML==5.4.1
@@ -123,13 +127,13 @@ SimpleITK==2.0.2
123
  six==1.15.0
124
  sklearn==0.0
125
  statsmodels==0.12.2
126
- tensorboard==1.13.1
 
127
  tensorboard-data-server==0.6.1
128
- tensorboard-plugin-wit==1.8.0
129
  tensorflow-addons==0.14.0
130
  tensorflow-datasets==4.4.0
131
- tensorflow-estimator==1.13.0
132
- tensorflow-gpu==1.13.1
133
  tensorflow-metadata==1.2.0
134
  termcolor==1.1.0
135
  terminado==0.9.4
@@ -141,6 +145,7 @@ tikzplotlib==0.9.7
141
  tornado==6.1
142
  tqdm==4.56.0
143
  traitlets==4.3.3
 
144
  transformations==2020.1.1
145
  trimesh==3.9.29
146
  typeguard==2.12.1
 
27
  fastrlock==0.6
28
  flatbuffers==1.12
29
  future==0.18.2
30
+ gast==0.2.2
31
  google-auth==1.35.0
32
  google-auth-oauthlib==0.4.6
33
  google-pasta==0.2.0
34
  googleapis-common-protos==1.53.0
35
  grpcio==1.40.0
36
+ h5py==2.10.0
37
  idna==2.10
38
  imageio==2.9.0
39
  importlib-metadata==3.4.0
40
  importlib-resources==5.2.2
41
+ ipydatawidgets==4.2.0
42
  ipykernel==5.5.3
43
  ipython==7.16.1
44
  ipython-genutils==0.2.0
45
+ ipyvolume==0.5.2
46
+ ipywebrtc==0.6.0
47
  ipywidgets==7.6.3
48
  jedi==0.18.0
49
  Jinja2==2.11.3
 
87
  pexpect==4.8.0
88
  pickleshare==0.7.5
89
  Pillow==8.1.0
90
+ pkg_resources==0.0.0
91
  plotly==4.14.3
92
  plyfile==0.7.3
93
  probreg==0.3.1
 
110
  pystrum==0.1
111
  python-dateutil==2.8.1
112
  python-utils==2.5.6
113
+ pythreejs==2.3.0
114
  pytz==2021.1
115
  PyWavelets==1.1.1
116
  PyYAML==5.4.1
 
127
  six==1.15.0
128
  sklearn==0.0
129
  statsmodels==0.12.2
130
+ tabulate==0.8.9
131
+ tensorboard==1.14.0
132
  tensorboard-data-server==0.6.1
 
133
  tensorflow-addons==0.14.0
134
  tensorflow-datasets==4.4.0
135
+ tensorflow-estimator==1.14.0
136
+ tensorflow-gpu==1.14.0
137
  tensorflow-metadata==1.2.0
138
  termcolor==1.1.0
139
  terminado==0.9.4
 
145
  tornado==6.1
146
  tqdm==4.56.0
147
  traitlets==4.3.3
148
+ traittypes==0.2.1
149
  transformations==2020.1.1
150
  trimesh==3.9.29
151
  typeguard==2.12.1