soumickmj committed
Commit e03497c
1 Parent(s): 5192ddb

fullvol auto-padding added

Files changed (1): app.py (+23, -0)
app.py CHANGED
@@ -1,8 +1,10 @@
 import streamlit as st
 import json
+import math
 import numpy as np
 import nibabel as nib
 import torch
+import torch.nn.functional as F
 import scipy.io
 from io import BytesIO
 from transformers import AutoModel
@@ -15,11 +17,32 @@ from skimage.filters import threshold_otsu
 def infer_full_vol(tensor, model):
     tensor = torch.movedim(tensor, -1, -3)
     tensor = tensor / tensor.max()
+
+    sizes = tensor.shape[-3:]
+    new_sizes = [math.ceil(s / 16) * 16 for s in sizes]
+    total_pads = [new_size - s for s, new_size in zip(sizes, new_sizes)]
+    pad_before = [pad // 2 for pad in total_pads]
+    pad_after = [pad - pad_before[i] for i, pad in enumerate(total_pads)]
+    padding = []
+    for i in reversed(range(len(pad_before))):
+        padding.extend([pad_before[i], pad_after[i]])
+    tensor = F.pad(tensor, padding)
+
     with torch.no_grad():
         output = model(tensor)
     if type(output) is tuple or type(output) is list:
         output = output[0]
     output = torch.sigmoid(output)
+
+    slices = [slice(None)] * output.dim()
+    for i in range(len(pad_before)):
+        dim = -3 + i
+        start = pad_before[i]
+        size = sizes[i]
+        end = start + size
+        slices[dim] = slice(start, end)
+    output = output[tuple(slices)]
+
     output = torch.movedim(output, -3, -1).type(tensor.type())
     return output.squeeze().detach().cpu().numpy()
 
48