File size: 3,129 Bytes
7226a40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, List, Union

import numpy as np
import torch
from monai.apps.detection.networks.retinanet_detector import RetinaNetDetector
from monai.inferers.inferer import Inferer
from torch import Tensor


class RetinaNetInferer(Inferer):
    """
    RetinaNet Inferer that runs a ``RetinaNetDetector`` with a given network.

    On every call the network is attached to the detector, and the inferer decides
    whether the detector should forward the network directly or go through its
    sliding-window inferer (``detector.inferer``).

    Args:
        detector: the RetinaNetDetector that converts network output BxCxMxN or BxCxMxNxP
            map into boxes and classification scores.
        force_sliding_window: whether to force using a SlidingWindowInferer to do the inference.
            If False, will check the input size to decide whether to simply
            forward the network or use the detector's sliding-window inferer.
            If True, will always request the sliding-window inferer.
    """

    def __init__(self, detector: RetinaNetDetector, force_sliding_window: bool = False) -> None:
        super().__init__()
        self.detector = detector
        self.force_sliding_window = force_sliding_window
        # Number of elements in one sliding-window ROI (product of ``roi_size``).
        # Stays None when the detector has no inferer, or its inferer has no ``roi_size``.
        self.sliding_window_size = None
        inferer = self.detector.inferer
        if inferer is not None and hasattr(inferer, "roi_size"):
            self.sliding_window_size = np.prod(inferer.roi_size)

    def __call__(self, inputs: Union[List[Tensor], Tensor], network: torch.nn.Module, *args: Any, **kwargs: Any):
        """Unified callable function API of Inferers.

        Args:
            inputs: model input data for inference, either a list of per-image tensors
                or a single tensor that is iterated over its first dimension.
            network: target detection network to execute inference.
                Supports any callable that fulfills the requirements of ``network`` in
                ``monai.apps.detection.networks.retinanet_detector.RetinaNetDetector``.
            args: optional args to be passed to the detector call.
            kwargs: optional keyword args to be passed to the detector call.

        Returns:
            the result of ``self.detector(inputs, *args, use_inferer=..., **kwargs)``.
        """
        self.detector.network = network
        # Mirror the network's train/eval flag onto the detector.
        self.detector.training = self.detector.network.training

        # Small images (fewer elements in ``data_i[0, ...]`` than one ROI window) are
        # forwarded directly. The sliding-window inferer is used when forced, or when at
        # least one image is at least as large as the ROI window.
        # NOTE(review): ``data_i[0, ...]`` presumably selects the first channel of a
        # CxHxW[xD] image — confirm against callers.
        use_inferer = self.force_sliding_window or (
            self.sliding_window_size is not None
            and any(data_i[0, ...].numel() >= self.sliding_window_size for data_i in inputs)
        )

        return self.detector(inputs, *args, use_inferer=use_inferer, **kwargs)