# Spaces: Build error
import cv2 | |
import torch | |
import numpy as np | |
import numpy.typing as npt | |
from torch import Tensor | |
from kornia.morphology import erosion, dilation | |
def clean_mask_torch(mask: Tensor) -> Tensor:
    """Clean a mask with a 2x2 morphological opening (erode, then dilate).

    The mask is eroded and then dilated with the same all-ones 2x2 kernel
    using kornia's ``erosion`` and ``dilation``.

    Args:
        mask: Mask tensor. A 2-D tensor is expanded with two leading
            singleton dimensions before processing; tensors of any other
            rank are handed to kornia unchanged (kornia documents a
            ``(B, C, H, W)`` input). Boolean masks are cast to ``int32``
            first.

    Returns:
        The result of ``dilation(erosion(mask, kernel), kernel)``. The
        dimensions added for a 2-D input are NOT squeezed away again, so,
        assuming kornia preserves the input shape as its docs state, a
        ``(H, W)`` input comes back as ``(1, 1, H, W)``.
    """
    # Create the structuring element directly on the mask's device rather
    # than allocating it on the CPU and copying it over afterwards.
    kernel = torch.ones(2, 2, device=mask.device)

    if mask.ndim == 2:
        # Prepend batch and channel axes for kornia.
        mask = mask[None, None, :, :]

    if mask.dtype == torch.bool:
        # NOTE(review): presumably kornia's morphology arithmetic does not
        # work on bool tensors, hence the integer cast -- confirm.
        mask = mask.int()

    return dilation(erosion(mask, kernel), kernel)
def clean_mask_np(mask: npt.NDArray) -> npt.NDArray:
    """Clean a NumPy mask with a 2x2 morphological opening via OpenCV.

    Args:
        mask: Mask array, passed as-is to ``cv2.erode``.

    Returns:
        The mask after one erosion pass followed by one dilation pass, both
        using the same all-ones 2x2 ``uint8`` structuring element.
    """
    structuring_element = np.ones((2, 2), np.uint8)

    # Step 1: erode once.
    eroded = cv2.erode(mask, structuring_element, iterations=1)

    # Step 2: dilate the eroded mask once with the same element.
    opened = cv2.dilate(eroded, structuring_element, iterations=1)
    return opened