Spaces:
Runtime error
Runtime error
File size: 1,063 Bytes
153628e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 |
# Copyright (C) 2021-2024, Mindee.
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
import torch
from torch import Tensor
from torch.nn.functional import max_pool2d
# Public API of this module: the two morphological operators defined below.
__all__ = ["erode", "dilate"]
def erode(x: Tensor, kernel_size: int) -> Tensor:
    """Performs erosion on a given tensor

    Erosion is computed as a sliding-window minimum, expressed as ``1 - max_pool2d(1 - x)``.
    For a binary mask, a pixel stays set only if every pixel of the square window around it
    is set. ``max_pool2d`` pads with -inf, so positions outside the image never contribute:
    the image edge itself does not erode the mask.

    Args:
    ----
        x: tensor of shape (N, C, H, W). Boolean masks are supported (the result is then
            boolean as well); otherwise a floating-point tensor, typically a 0/1 mask
        kernel_size: the size of the (square) kernel to use for erosion

    Returns:
    -------
        the eroded tensor, with the same shape as ``x``
    """
    if x.dtype == torch.bool:
        # Neither `1 - x` nor max_pool2d accept bool tensors: compute in float, cast back.
        return erode(x.float(), kernel_size).bool()
    # `kernel_size // 2` equals `(kernel_size - 1) // 2` for odd kernels (output is H x W).
    # For even kernels it yields one extra row/column, which is cropped below; with the
    # previous `(kernel_size - 1) // 2` padding the output would instead shrink by one.
    out = 1 - max_pool2d(1 - x, kernel_size, stride=1, padding=kernel_size // 2)
    return out[..., : x.shape[-2], : x.shape[-1]]
def dilate(x: Tensor, kernel_size: int) -> Tensor:
    """Performs dilation on a given tensor

    Dilation is computed as a sliding-window maximum via ``max_pool2d``. For a binary mask,
    a pixel becomes set if any pixel of the square window around it is set. ``max_pool2d``
    pads with -inf, so positions outside the image never contribute.

    Args:
    ----
        x: tensor of shape (N, C, H, W). Boolean masks are supported (the result is then
            boolean as well); otherwise a floating-point tensor, typically a 0/1 mask
        kernel_size: the size of the (square) kernel to use for dilation

    Returns:
    -------
        the dilated tensor, with the same shape as ``x``
    """
    if x.dtype == torch.bool:
        # max_pool2d does not accept bool tensors: compute in float, cast back.
        return dilate(x.float(), kernel_size).bool()
    # `kernel_size // 2` equals `(kernel_size - 1) // 2` for odd kernels (output is H x W).
    # For even kernels it yields one extra row/column, which is cropped below; with the
    # previous `(kernel_size - 1) // 2` padding the output would instead shrink by one.
    out = max_pool2d(x, kernel_size, stride=1, padding=kernel_size // 2)
    return out[..., : x.shape[-2], : x.shape[-1]]
|