from __future__ import annotations

from collections.abc import Callable, Iterable, Sequence
from typing import Any

import numpy as np
import torch
import torch.nn.functional as F

from monai.data.meta_tensor import MetaTensor
from monai.data.utils import compute_importance_map, dense_patch_slices, get_valid_patch_size
from monai.inferers.utils import (
    _compute_coords,
    _create_buffered_slices,
    _flatten_struct,
    _get_scan_interval,
    _pack_struct,
)
from monai.utils import (
    BlendMode,
    PytorchPadMode,
    convert_data_type,
    convert_to_dst_type,
    ensure_tuple,
    ensure_tuple_rep,
    fall_back_tuple,
    look_up_option,
    optional_import,
    pytorch_after,
)
from tqdm import tqdm

# interpolation mode used when a predictor output patch has to be resized;
# "nearest-exact" is only available from PyTorch 1.11 onwards
_nearest_mode = "nearest-exact" if pytorch_after(1, 11) else "nearest"


def sliding_window_inference(
    inputs: torch.Tensor | MetaTensor,
    roi_size: Sequence[int] | int,
    sw_batch_size: int,
    predictor: Callable[..., torch.Tensor | Sequence[torch.Tensor] | dict[Any, torch.Tensor]],
    overlap: Sequence[float] | float = 0.25,
    mode: BlendMode | str = BlendMode.CONSTANT,
    sigma_scale: Sequence[float] | float = 0.125,
    padding_mode: PytorchPadMode | str = PytorchPadMode.CONSTANT,
    cval: float = 0.0,
    sw_device: torch.device | str | None = None,
    device: torch.device | str | None = None,
    progress: bool = False,
    roi_weight_map: torch.Tensor | None = None,
    process_fn: Callable | None = None,
    buffer_steps: int | None = None,
    buffer_dim: int = -1,
    with_coord: bool = False,
    discard_second_output: bool = False,
    *args: Any,
    **kwargs: Any,
) -> torch.Tensor | tuple[torch.Tensor, ...] | dict[Any, torch.Tensor]:
    """
    Sliding window inference on `inputs` with `predictor`.

    The outputs of `predictor` can be a tensor, a tuple, or a dictionary of tensors.
    Each output in the tuple or dict value is allowed to have a different resolution with respect to the input.
    For example, if the input patch spatial size is [128,128,128], the output (a tuple of two patches) patch sizes
    could be ([128,64,256], [64,32,128]).
    In this case, the parameters `overlap` and `roi_size` need to be carefully chosen to ensure the output ROI is
    still an integer. If the predictor's input and output spatial sizes are not equal, we recommend choosing the
    parameters so that `overlap*roi_size*output_size/input_size` is an integer (for each spatial dimension).

    When `roi_size` is larger than the inputs' spatial size, the input image is padded during inference.
    To maintain the same spatial sizes, the output image is cropped to the original input size.

    Args:
        inputs: input image to be processed (assuming NCHW[D]).
        roi_size: the spatial window size for inferences.
            When its components are None or non-positive, the corresponding components of the image size are used.
            For example, `roi_size=(32, -1)` will be adapted to `(32, 64)` if the second spatial dimension size
            of the image is `64`.
        sw_batch_size: the batch size to run window slices.
        predictor: given an input tensor ``patch_data`` in shape NCHW[D],
            the outputs of the function call ``predictor(patch_data)`` should be a tensor, a tuple, or a dictionary
            with tensor values. Each output in the tuple or dict value should have the same batch size, i.e. NM'H'W'[D'];
            where H'W'[D'] represents the output patch's spatial size, M' is the number of output channels,
            and N is `sw_batch_size`. For example, if the input shape is (7, 1, 128, 128, 128),
            the output could be a tuple of two tensors with shapes ((7, 5, 128, 64, 256), (7, 4, 64, 32, 128)).
            In this case, the parameters `overlap` and `roi_size` need to be carefully chosen
            to ensure the scaled output ROI sizes are still integers.
            If the `predictor`'s input and output spatial sizes are different,
            we recommend choosing the parameters so that ``overlap*roi_size*zoom_scale`` is an integer for each dimension.
        overlap: amount of overlap between scans along each spatial dimension, defaults to ``0.25``.
        mode: {``"constant"``, ``"gaussian"``}
            How to blend output of overlapping windows. Defaults to ``"constant"``.

            - ``"constant"``: gives equal weight to all predictions.
            - ``"gaussian"``: gives less weight to predictions on edges of windows.

        sigma_scale: the standard deviation coefficient of the Gaussian window when `mode` is ``"gaussian"``.
            Default: 0.125. Actual window sigma is ``sigma_scale`` * ``dim_size``.
            When sigma_scale is a sequence of floats, the values denote sigma_scale at the corresponding
            spatial dimensions.
        padding_mode: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}
            Padding mode for ``inputs``, when ``roi_size`` is larger than inputs. Defaults to ``"constant"``.
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
        cval: fill value for ``"constant"`` padding mode. Default: 0.
        sw_device: device for the window data.
            By default the device (and accordingly the memory) of the `inputs` is used.
            Normally `sw_device` should be consistent with the device where `predictor` is defined.
        device: device for the stitched output prediction.
            By default the device (and accordingly the memory) of the `inputs` is used. If, for example,
            it is set to device=torch.device('cpu'), the gpu memory consumption is less and independent of the
            `inputs` and `roi_size`. Output is on the `device`.
        progress: whether to print a `tqdm` progress bar.
        roi_weight_map: pre-computed (non-negative) weight map for each ROI.
            If not given, and ``mode`` is not ``"constant"``, this map will be computed on the fly.
        process_fn: process the inference output and adjust the importance map per window.
        buffer_steps: the number of sliding window iterations along the ``buffer_dim``
            to be buffered on ``sw_device`` before writing to ``device``.
            (Typically, ``sw_device`` is ``cuda`` and ``device`` is ``cpu``.)
            Default is None, no buffering. For the buffer dim, when the spatial size is divisible by
            buffer_steps*roi_size (i.e. no overlapping among the buffers), a non_blocking copy may be
            automatically enabled for efficiency.
        buffer_dim: the spatial dimension along which the buffers are created.
            0 indicates the first spatial dimension. Default is -1, the last spatial dimension.
        with_coord: whether to pass the window coordinates to ``predictor``. Default is False.
            If True, the signature of ``predictor`` should be ``predictor(patch_data, patch_coord, ...)``.
        discard_second_output: if True, only the first element of the ``predictor`` output is kept for
            stitching and any additional outputs are discarded. Default is False.
        args: optional args to be passed to ``predictor``.
        kwargs: optional keyword args to be passed to ``predictor``.
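
    Example (illustrative sketch only; ``net`` and ``image`` are placeholder names for a 3D network
    and an NCDHW tensor, not part of this module)::

        output = sliding_window_inference(
            inputs=image, roi_size=(96, 96, 96), sw_batch_size=4, predictor=net, overlap=0.5, mode="gaussian"
        )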

    Note:
        - input must be channel-first and have a batch dim, supports N-D sliding window.

    """
    buffered = buffer_steps is not None and buffer_steps > 0
    num_spatial_dims = len(inputs.shape) - 2
    if buffered:
        if buffer_dim < -num_spatial_dims or buffer_dim > num_spatial_dims:
            raise ValueError(f"buffer_dim must be in [{-num_spatial_dims}, {num_spatial_dims}], got {buffer_dim}.")
        if buffer_dim < 0:
            buffer_dim += num_spatial_dims
    overlap = ensure_tuple_rep(overlap, num_spatial_dims)
    for o in overlap:
        if o < 0 or o >= 1:
            raise ValueError(f"overlap must be >= 0 and < 1, got {overlap}.")
    compute_dtype = inputs.dtype
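
    # determine image spatial size, batch size and target devices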
    batch_size, _, *image_size_ = inputs.shape
    device = device or inputs.device
    sw_device = sw_device or inputs.device

    temp_meta = None
    if isinstance(inputs, MetaTensor):
        temp_meta = MetaTensor([]).copy_meta_from(inputs, copy_attr=False)
    inputs = convert_data_type(inputs, torch.Tensor, wrap_sequence=True)[0]
    roi_size = fall_back_tuple(roi_size, image_size_)
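
    # pad the inputs when the image spatial size is smaller than roi_size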
    image_size = tuple(max(image_size_[i], roi_size[i]) for i in range(num_spatial_dims))
    pad_size = []
    for k in range(len(inputs.shape) - 1, 1, -1):
        diff = max(roi_size[k - 2] - inputs.shape[k], 0)
        half = diff // 2
        pad_size.extend([half, diff - half])
    if any(pad_size):
        inputs = F.pad(inputs, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode), value=cval)
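
    # compute all the sliding window slices over the (padded) image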
    scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap)
    slices = dense_patch_slices(image_size, roi_size, scan_interval, return_slice=not buffered)

    num_win = len(slices)
    total_slices = num_win * batch_size
    windows_range: Iterable
    if not buffered:
        non_blocking = False
        windows_range = range(0, total_slices, sw_batch_size)
    else:
        slices, n_per_batch, b_slices, windows_range = _create_buffered_slices(
            slices, batch_size, sw_batch_size, buffer_dim, buffer_steps
        )
        non_blocking, _ss = torch.cuda.is_available(), -1
        for x in b_slices[:n_per_batch]:
            if x[1] < _ss:
                non_blocking = False
                break
            _ss = x[2]
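
    # create the importance map (the blending weight map) for each window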
    valid_patch_size = get_valid_patch_size(image_size, roi_size)
    if valid_patch_size == roi_size and (roi_weight_map is not None):
        importance_map_ = roi_weight_map
    else:
        try:
            valid_p_size = ensure_tuple(valid_patch_size)
            importance_map_ = compute_importance_map(
                valid_p_size, mode=mode, sigma_scale=sigma_scale, device=sw_device, dtype=compute_dtype
            )
            if len(importance_map_.shape) == num_spatial_dims and not process_fn:
                importance_map_ = importance_map_[None, None]
        except Exception as e:
            raise RuntimeError(
                f"patch size {valid_p_size}, mode={mode}, sigma_scale={sigma_scale}, device={device}\n"
                "Seems to be OOM. Please try smaller patch size or mode='constant' instead of mode='gaussian'."
            ) from e
    importance_map_ = convert_data_type(importance_map_, torch.Tensor, device=sw_device, dtype=compute_dtype)[0]
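
    # stores the stitched output image(s) and the corresponding count map(s)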
    output_image_list, count_map_list, sw_device_buffer, b_s, b_i = [], [], [], 0, 0

    for slice_g in tqdm(windows_range) if progress else windows_range:
        slice_range = range(slice_g, min(slice_g + sw_batch_size, b_slices[b_s][0] if buffered else total_slices))
        unravel_slice = [
            [slice(idx // num_win, idx // num_win + 1), slice(None)] + list(slices[idx % num_win])
            for idx in slice_range
        ]
        if sw_batch_size > 1:
            win_data = torch.cat([inputs[win_slice] for win_slice in unravel_slice]).to(sw_device)
        else:
            win_data = inputs[unravel_slice[0]].to(sw_device)
        if with_coord:
            seg_prob_out = predictor(win_data, unravel_slice, *args, **kwargs)
        else:
            seg_prob_out = predictor(win_data, *args, **kwargs)
        if discard_second_output and seg_prob_out is not None:
            seg_prob_out = seg_prob_out[0]
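
        # flatten the predictor output into a tuple of tensors (and dict keys, if any)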
        dict_keys, seg_tuple = _flatten_struct(seg_prob_out)
        if process_fn:
            seg_tuple, w_t = process_fn(seg_tuple, win_data, importance_map_)
        else:
            w_t = importance_map_
        if len(w_t.shape) == num_spatial_dims:
            w_t = w_t[None, None]
        w_t = w_t.to(dtype=compute_dtype, device=sw_device)
        if buffered:
            c_start, c_end = b_slices[b_s][1:]
            if not sw_device_buffer:
                k = seg_tuple[0].shape[1]
                sp_size = list(image_size)
                sp_size[buffer_dim] = c_end - c_start
                sw_device_buffer = [torch.zeros(size=[1, k, *sp_size], dtype=compute_dtype, device=sw_device)]
            for p, s in zip(seg_tuple[0], unravel_slice):
                offset = s[buffer_dim + 2].start - c_start
                s[buffer_dim + 2] = slice(offset, offset + roi_size[buffer_dim])
                s[0] = slice(0, 1)
                sw_device_buffer[0][s] += p * w_t
            b_i += len(unravel_slice)
            if b_i < b_slices[b_s][0]:
                continue
        else:
            sw_device_buffer = list(seg_tuple)
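
        # write each weighted window prediction into the full-size output tensor(s)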
        for ss in range(len(sw_device_buffer)):
            b_shape = sw_device_buffer[ss].shape
            seg_chns, seg_shape = b_shape[1], b_shape[2:]
            z_scale = None
            if not buffered and seg_shape != roi_size:
                z_scale = [out_w_i / float(in_w_i) for out_w_i, in_w_i in zip(seg_shape, roi_size)]
                w_t = F.interpolate(w_t, seg_shape, mode=_nearest_mode)
            if len(output_image_list) <= ss:
                output_shape = [batch_size, seg_chns]
                output_shape += [int(_i * _z) for _i, _z in zip(image_size, z_scale)] if z_scale else list(image_size)
                new_tensor: Callable = torch.empty if non_blocking else torch.zeros
                output_image_list.append(new_tensor(output_shape, dtype=compute_dtype, device=device))
                count_map_list.append(torch.zeros([1, 1] + output_shape[2:], dtype=compute_dtype, device=device))
                w_t_ = w_t.to(device)
                for __s in slices:
                    if z_scale is not None:
                        __s = tuple(slice(int(_si.start * z_s), int(_si.stop * z_s)) for _si, z_s in zip(__s, z_scale))
                    count_map_list[-1][(slice(None), slice(None), *__s)] += w_t_
            if buffered:
                o_slice = [slice(None)] * len(inputs.shape)
                o_slice[buffer_dim + 2] = slice(c_start, c_end)
                img_b = b_s // n_per_batch
                o_slice[0] = slice(img_b, img_b + 1)
                if non_blocking:
                    output_image_list[0][o_slice].copy_(sw_device_buffer[0], non_blocking=non_blocking)
                else:
                    output_image_list[0][o_slice] += sw_device_buffer[0].to(device=device)
            else:
                sw_device_buffer[ss] *= w_t
                sw_device_buffer[ss] = sw_device_buffer[ss].to(device)
                _compute_coords(unravel_slice, z_scale, output_image_list[ss], sw_device_buffer[ss])
        sw_device_buffer = []
        if buffered:
            b_s += 1

    if non_blocking:
        torch.cuda.current_stream().synchronize()
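
    # account for any overlapping sections by normalising with the count map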
    for ss in range(len(output_image_list)):
        output_image_list[ss] /= count_map_list.pop(0)
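
    # remove the padding if the original image spatial size was smaller than roi_size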
    if any(pad_size):
        for ss, output_i in enumerate(output_image_list):
            zoom_scale = [_shape_d / _roi_size_d for _shape_d, _roi_size_d in zip(output_i.shape[2:], roi_size)]
            final_slicing: list[slice] = []
            for sp in range(num_spatial_dims):
                si = num_spatial_dims - sp - 1
                slice_dim = slice(
                    int(round(pad_size[sp * 2] * zoom_scale[si])),
                    int(round((pad_size[sp * 2] + image_size_[si]) * zoom_scale[si])),
                )
                final_slicing.insert(0, slice_dim)
            output_image_list[ss] = output_i[(slice(None), slice(None), *final_slicing)]
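
    # re-pack the outputs into the original structure (tensor / tuple / dict) and restore metadata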
    final_output = _pack_struct(output_image_list, dict_keys)
    if temp_meta is not None:
        final_output = convert_to_dst_type(final_output, temp_meta, device=device)[0]
    else:
        final_output = convert_to_dst_type(final_output, inputs, device=device)[0]

    return final_output


def sw_inference(model, input, roi_size, autocast_on, discard_second_output, overlap=0.8):
    """Run `sliding_window_inference` with ``sw_batch_size=1`` and constant blending,
    optionally under CUDA automatic mixed precision (``autocast_on``)."""

    def _compute(input):
        return sliding_window_inference(
            inputs=input,
            roi_size=roi_size,
            sw_batch_size=1,
            predictor=model,
            overlap=overlap,
            progress=False,
            mode="constant",
            discard_second_output=discard_second_output,
        )

    if autocast_on:
        with torch.cuda.amp.autocast():
            return _compute(input)
    else:
        return _compute(input)
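

# Minimal usage sketch (illustrative only, not part of the original module): a small Conv3d
# stands in for a real segmentation network, and the tensor shape and roi_size are arbitrary
# example values chosen so the script runs quickly on CPU.
if __name__ == "__main__":
    _model = torch.nn.Conv3d(1, 2, kernel_size=3, padding=1)  # placeholder "predictor"
    _image = torch.rand(1, 1, 64, 64, 48)  # NCDHW input volume
    _pred = sw_inference(_model, _image, roi_size=(32, 32, 32), autocast_on=False, discard_second_output=False)
    print(_pred.shape)  # expected: torch.Size([1, 2, 64, 64, 48])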