adirathor07's picture
added doctr folder
153628e
# Copyright (C) 2021-2024, Mindee.
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
import torch
from torch import Tensor
from torch.nn.functional import max_pool2d
# Public API of this module: the two morphological operators.
__all__ = ["erode", "dilate"]
def erode(x: Tensor, kernel_size: int) -> Tensor:
    """Performs erosion on a given tensor

    Erosion is computed as the complement of a max-pooling (stride 1, padding
    ``(kernel_size - 1) // 2``) applied to the complement of the input. With an odd
    ``kernel_size`` the spatial size is preserved; with an even one, each spatial
    dimension of the output is one element smaller than the input.

    Args:
    ----
        x: boolean tensor of shape (N, C, H, W), either of dtype ``torch.bool`` or a
            floating point mask holding 0/1 values
        kernel_size: the size of the kernel to use for erosion

    Returns:
    -------
        the eroded tensor, with the same dtype as the input
    """
    _pad = (kernel_size - 1) // 2
    if x.dtype == torch.bool:
        # PyTorch rejects `1 - x` on bool tensors, so invert with `~` and pool in float.
        # A pooled value of 0 means the window contains no background pixel.
        return max_pool2d((~x).float(), kernel_size, stride=1, padding=_pad) == 0
    return 1 - max_pool2d(1 - x, kernel_size, stride=1, padding=_pad)
def dilate(x: Tensor, kernel_size: int) -> Tensor:
    """Performs dilation on a given tensor

    Dilation is computed as a max-pooling with stride 1 and padding
    ``(kernel_size - 1) // 2``. With an odd ``kernel_size`` the spatial size is
    preserved; with an even one, each spatial dimension of the output is one element
    smaller than the input.

    Args:
    ----
        x: boolean tensor of shape (N, C, H, W), either of dtype ``torch.bool`` or a
            floating point mask holding 0/1 values
        kernel_size: the size of the kernel to use for dilation

    Returns:
    -------
        the dilated tensor, with the same dtype as the input
    """
    _pad = (kernel_size - 1) // 2
    if x.dtype == torch.bool:
        # Pool in float so that bool masks are handled the same way as in `erode`,
        # then threshold back: a positive pooled value means a foreground pixel in the window.
        return max_pool2d(x.float(), kernel_size, stride=1, padding=_pad) > 0
    return max_pool2d(x, kernel_size, stride=1, padding=_pad)