import cv2
import torch
import numpy as np
import numpy.typing as npt
from torch import Tensor
from kornia.morphology import erosion, dilation


def clean_mask_torch(mask: Tensor) -> Tensor:
    """Apply a morphological opening (erosion, then dilation) to a tensor mask.

    A 2x2 all-ones structuring element is used for both steps.

    Args:
        mask: Mask tensor. A 2-D tensor is expanded with two leading singleton
            dimensions before processing; tensors of any other rank are passed
            to kornia unchanged (kornia presumably expects a 4-D
            (batch, channel, height, width) input -- TODO confirm). Boolean
            masks are converted to integers first.

    Returns:
        The result of ``dilation(erosion(mask, kernel), kernel)``. For a 2-D
        input the two leading dimensions added here are not removed again.
        NOTE(review): the output dtype is whatever kornia produces for the
        given input/kernel dtypes -- verify before relying on it.
    """
    # Create the structuring element directly on the mask's device instead of
    # allocating on the CPU and copying it over.
    kernel = torch.ones(2, 2, device=mask.device)

    if mask.ndim == 2:
        # Add leading singleton dimensions so the mask is 4-D.
        mask = mask[None, None, :, :]

    if mask.dtype == torch.bool:
        # Switch to an integer dtype before the arithmetic morphology ops.
        mask = mask.int()

    return dilation(erosion(mask, kernel), kernel)


def clean_mask_np(mask: npt.NDArray) -> npt.NDArray:
    """Apply a morphological opening (erosion, then dilation) to a NumPy mask.

    A 2x2 all-ones ``uint8`` structuring element is used for both steps, each
    run for a single iteration.

    Args:
        mask: Mask array handed to OpenCV. Boolean arrays are converted to
            ``uint8`` first (mirroring the bool handling in
            ``clean_mask_torch``), because OpenCV's morphology functions do
            not accept ``bool`` input. Other dtypes are passed through
            unchanged and must be ones OpenCV supports.

    Returns:
        The opened mask as returned by ``cv2.dilate``. For boolean input the
        result is a ``uint8`` array rather than a boolean one.
    """
    kernel = np.ones((2, 2), np.uint8)

    if mask.dtype == np.bool_:
        # cv2.erode / cv2.dilate reject bool arrays; 0/1 uint8 is equivalent.
        mask = mask.astype(np.uint8)

    eroded = cv2.erode(mask, kernel, iterations=1)
    return cv2.dilate(eroded, kernel, iterations=1)