|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Any, List, Union |
|
|
|
import numpy as np |
|
import torch |
|
from monai.apps.detection.networks.retinanet_detector import RetinaNetDetector |
|
from monai.inferers.inferer import Inferer |
|
from torch import Tensor |
|
|
|
|
|
class RetinaNetInferer(Inferer): |
|
""" |
|
RetinaNet Inferer takes RetinaNet as input |
|
|
|
Args: |
|
detector: the RetinaNetDetector that converts network output BxCxMxN or BxCxMxNxP |
|
map into boxes and classification scores. |
|
force_sliding_window: whether to force using a SlidingWindowInferer to do the inference. |
|
If False, will check the input spatial size to decide whether to simply |
|
forward the network or using SlidingWindowInferer. |
|
If True, will force using SlidingWindowInferer to do the inference. |
|
args: other optional args to be passed to detector. |
|
kwargs: other optional keyword args to be passed to detector. |
|
""" |
|
|
|
def __init__(self, detector: RetinaNetDetector, force_sliding_window: bool = False) -> None: |
|
Inferer.__init__(self) |
|
self.detector = detector |
|
self.sliding_window_size = None |
|
self.force_sliding_window = force_sliding_window |
|
if self.detector.inferer is not None: |
|
if hasattr(self.detector.inferer, "roi_size"): |
|
self.sliding_window_size = np.prod(self.detector.inferer.roi_size) |
|
|
|
def __call__(self, inputs: Union[List[Tensor], Tensor], network: torch.nn.Module, *args: Any, **kwargs: Any): |
|
"""Unified callable function API of Inferers. |
|
Args: |
|
inputs: model input data for inference. |
|
network: target detection network to execute inference. |
|
supports callable that fullfilles requirements of network in |
|
monai.apps.detection.networks.retinanet_detector.RetinaNetDetector`` |
|
args: optional args to be passed to ``network``. |
|
kwargs: optional keyword args to be passed to ``network``. |
|
""" |
|
self.detector.network = network |
|
self.detector.training = self.detector.network.training |
|
|
|
|
|
|
|
use_inferer = ( |
|
self.force_sliding_window |
|
or self.sliding_window_size is not None |
|
and not all([data_i[0, ...].numel() < self.sliding_window_size for data_i in inputs]) |
|
) |
|
|
|
return self.detector(inputs, *args, use_inferer=use_inferer, **kwargs) |
|
|