Spaces:
Runtime error
Runtime error
fullvol auto-padding added
Browse files
app.py
CHANGED
@@ -1,8 +1,10 @@
|
|
1 |
import streamlit as st
|
2 |
import json
|
|
|
3 |
import numpy as np
|
4 |
import nibabel as nib
|
5 |
import torch
|
|
|
6 |
import scipy.io
|
7 |
from io import BytesIO
|
8 |
from transformers import AutoModel
|
@@ -15,11 +17,32 @@ from skimage.filters import threshold_otsu
|
|
15 |
def infer_full_vol(tensor, model):
|
16 |
tensor = torch.movedim(tensor, -1, -3)
|
17 |
tensor = tensor / tensor.max()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
with torch.no_grad():
|
19 |
output = model(tensor)
|
20 |
if type(output) is tuple or type(output) is list:
|
21 |
output = output[0]
|
22 |
output = torch.sigmoid(output)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
output = torch.movedim(output, -3, -1).type(tensor.type())
|
24 |
return output.squeeze().detach().cpu().numpy()
|
25 |
|
|
|
1 |
import streamlit as st
|
2 |
import json
|
3 |
+
import math
|
4 |
import numpy as np
|
5 |
import nibabel as nib
|
6 |
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
import scipy.io
|
9 |
from io import BytesIO
|
10 |
from transformers import AutoModel
|
|
|
17 |
def infer_full_vol(tensor, model):
|
18 |
tensor = torch.movedim(tensor, -1, -3)
|
19 |
tensor = tensor / tensor.max()
|
20 |
+
|
21 |
+
sizes = tensor.shape[-3:]
|
22 |
+
new_sizes = [math.ceil(s / 16) * 16 for s in sizes]
|
23 |
+
total_pads = [new_size - s for s, new_size in zip(sizes, new_sizes)]
|
24 |
+
pad_before = [pad // 2 for pad in total_pads]
|
25 |
+
pad_after = [pad - pad_before[i] for i, pad in enumerate(total_pads)]
|
26 |
+
padding = []
|
27 |
+
for i in reversed(range(len(pad_before))):
|
28 |
+
padding.extend([pad_before[i], pad_after[i]])
|
29 |
+
tensor = F.pad(tensor, padding)
|
30 |
+
|
31 |
with torch.no_grad():
|
32 |
output = model(tensor)
|
33 |
if type(output) is tuple or type(output) is list:
|
34 |
output = output[0]
|
35 |
output = torch.sigmoid(output)
|
36 |
+
|
37 |
+
slices = [slice(None)] * output.dim()
|
38 |
+
for i in range(len(pad_before)):
|
39 |
+
dim = -3 + i
|
40 |
+
start = pad_before[i]
|
41 |
+
size = sizes[i]
|
42 |
+
end = start + size
|
43 |
+
slices[dim] = slice(start, end)
|
44 |
+
output = output[tuple(slices)]
|
45 |
+
|
46 |
output = torch.movedim(output, -3, -1).type(tensor.type())
|
47 |
return output.squeeze().detach().cpu().numpy()
|
48 |
|