File size: 1,063 Bytes
153628e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
# Copyright (C) 2021-2024, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from torch import Tensor
from torch.nn.functional import max_pool2d, pad

# Public API of this module (what `from ... import *` exposes): the two morphology operations.
__all__ = ["erode", "dilate"]


def _max_filter(x: Tensor, kernel_size: int) -> Tensor:
    """Sliding-window maximum (stride 1) whose output keeps the spatial size of `x`

    Positions outside the image never win the maximum: they act as -inf, which is also how
    `max_pool2d` treats its implicit padding.

    Args:
    ----
        x: floating-point tensor of shape (N, C, H, W)
        kernel_size: side length of the square window

    Returns:
    -------
        tensor with the same shape as `x`
    """
    if kernel_size % 2 == 1:
        # Odd window: centered, so symmetric implicit padding already preserves H and W.
        return max_pool2d(x, kernel_size, stride=1, padding=(kernel_size - 1) // 2)
    # Even window: symmetric padding of (k - 1) // 2 would drop the last row/column and
    # max_pool2d cannot pad asymmetrically, so pad explicitly with one extra -inf row/column
    # at the bottom/right. Window alignment is the same as with the symmetric padding, so the
    # first H - 1 rows / W - 1 columns equal what the plain max_pool2d call produces.
    before = (kernel_size - 1) // 2
    after = kernel_size - 1 - before
    padded = pad(x, (before, after, before, after), value=float("-inf"))
    return max_pool2d(padded, kernel_size, stride=1)


def erode(x: Tensor, kernel_size: int) -> Tensor:
    """Performs erosion on a given tensor

    Each output pixel is the minimum of `x` over the `kernel_size` x `kernel_size` window
    around it. Positions outside the image are ignored, and the spatial size is preserved
    for both odd and even kernel sizes.

    Args:
    ----
        x: boolean or numeric tensor of shape (N, C, H, W)
        kernel_size: the size of the kernel to use for erosion

    Returns:
    -------
        the eroded tensor, with the same shape and dtype as `x`
        (non-floating inputs are processed in float32 and cast back)
    """
    # max_pool2d and unary minus do not accept bool tensors: work in float, cast back at the end.
    work = x if x.is_floating_point() else x.float()
    # min(x) == -max(-x) exactly for every float, unlike 1 - max(1 - x), which rounds
    # values that are not exactly 0 or 1.
    return (-_max_filter(-work, kernel_size)).to(x.dtype)


def dilate(x: Tensor, kernel_size: int) -> Tensor:
    """Performs dilation on a given tensor

    Each output pixel is the maximum of `x` over the `kernel_size` x `kernel_size` window
    around it. Positions outside the image are ignored, and the spatial size is preserved
    for both odd and even kernel sizes.

    Args:
    ----
        x: boolean or numeric tensor of shape (N, C, H, W)
        kernel_size: the size of the kernel to use for dilation

    Returns:
    -------
        the dilated tensor, with the same shape and dtype as `x`
        (non-floating inputs are processed in float32 and cast back)
    """
    # max_pool2d does not accept bool tensors: work in float, cast back at the end.
    work = x if x.is_floating_point() else x.float()
    return _max_filter(work, kernel_size).to(x.dtype)