soumickmj commited on
Commit
d1efa41
1 Parent(s): 3bef49f

normalisation for patch-infer fixed

Browse files
Files changed (1) hide show
  1. app.py +11 -1
app.py CHANGED
@@ -48,8 +48,18 @@ def infer_full_vol(tensor, model):
48
 
49
  def infer_patch_based(tensor, model, patch_size=64, stride_length=32, stride_width=32, stride_depth=16, batch_size=10, num_worker=2):
50
  test_subject = tio.Subject(img = tio.ScalarImage(tensor=tensor.unsqueeze(0))) # adding channel dim while creating the TorchIO subject
 
51
  overlap = np.subtract(patch_size, (stride_length, stride_width, stride_depth))
52
 
 
 
 
 
 
 
 
 
 
53
  with torch.no_grad():
54
  grid_sampler = tio.inference.GridSampler(
55
  test_subject,
@@ -63,7 +73,7 @@ def infer_patch_based(tensor, model, patch_size=64, stride_length=32, stride_wid
63
  for i, patches_batch in enumerate(patch_loader):
64
  st.text(f"Processing batch {i + 1} of {total_batches}...")
65
 
66
- local_batch = patches_batch['img'][tio.DATA].float()
67
  local_batch = local_batch / local_batch.max()
68
  locations = patches_batch[tio.LOCATION]
69
 
 
48
 
49
  def infer_patch_based(tensor, model, patch_size=64, stride_length=32, stride_width=32, stride_depth=16, batch_size=10, num_worker=2):
50
  test_subject = tio.Subject(img = tio.ScalarImage(tensor=tensor.unsqueeze(0))) # adding channel dim while creating the TorchIO subject
51
+
52
  overlap = np.subtract(patch_size, (stride_length, stride_width, stride_depth))
53
 
54
+ def normaliser(batch):
55
+ """
56
+ Purpose: Normalise pixel intensities of each patch using the max values in the 3D patch
57
+ :param batch: 5D array (batch_size x channel x width x depth x height)
58
+ """
59
+ for i in range(batch.shape[0]):
60
+ batch[i] = batch[i] / batch[i].max()
61
+ return batch
62
+
63
  with torch.no_grad():
64
  grid_sampler = tio.inference.GridSampler(
65
  test_subject,
 
73
  for i, patches_batch in enumerate(patch_loader):
74
  st.text(f"Processing batch {i + 1} of {total_batches}...")
75
 
76
+ local_batch = normaliser(patches_batch['img'][tio.DATA].float())
77
  local_batch = local_batch / local_batch.max()
78
  locations = patches_batch[tio.LOCATION]
79