auto-labeler / app /sam /postprocess.py
dillonlaird's picture
added postprocessing
d9697ef
raw
history blame contribute delete
592 Bytes
import cv2
import torch
import numpy as np
import numpy.typing as npt
from torch import Tensor
from kornia.morphology import erosion, dilation
def clean_mask_torch(mask: Tensor, kernel_size: int = 2) -> Tensor:
    """Remove small foreground specks from a mask with a morphological opening.

    The mask is eroded and then dilated with a ``kernel_size`` x ``kernel_size``
    square of ones. Foreground regions that cannot contain that square are
    removed. Larger regions are restored to their original position.

    Args:
        mask: A 2-D ``(H, W)`` mask, or a 4-D tensor in the layout kornia's
            morphology ops accept (presumably ``(B, C, H, W)`` — confirm
            against callers). Other ranks are passed to kornia unchanged.
            Boolean masks are converted to ``int`` before processing.
        kernel_size: Side length of the square structuring element. The
            default of 2 matches the previously hard-coded kernel.

    Returns:
        The opened mask. A 2-D input comes back as ``(1, 1, H, W)``: the
        added leading dimensions are not removed. The dtype is whatever
        kornia produces, which is not necessarily the input dtype.
    """
    if mask.ndim == 2:
        # Lift (H, W) to (1, 1, H, W) before handing it to kornia.
        mask = mask[None, None, :, :]
    if mask.dtype == torch.bool:
        mask = mask.int()

    # Allocate directly on the mask's device instead of building on CPU and copying.
    kernel = torch.ones(kernel_size, kernel_size, device=mask.device)

    # kornia's default origin is (k // 2, k // 2). For an even kernel this
    # makes each output pixel read the up/left neighbourhood {i-1, i}.
    # Dilating with that same origin reads up/left again, which shifts
    # every surviving region one pixel down/right. Dilating with the
    # reflected origin turns erode+dilate into a true opening that leaves
    # regions in place. For odd kernels both origins coincide.
    centre = kernel_size // 2
    reflected = kernel_size - 1 - centre
    eroded = erosion(mask, kernel, origin=[centre, centre])
    return dilation(eroded, kernel, origin=[reflected, reflected])
def clean_mask_np(mask: npt.NDArray, kernel_size: int = 2) -> npt.NDArray:
    """Remove small foreground specks from a mask with a morphological opening.

    The mask is eroded and then dilated once each with a
    ``kernel_size`` x ``kernel_size`` square of ones. Foreground regions that
    cannot contain that square are removed. Larger regions are restored to
    their original position.

    Args:
        mask: A 2-D mask array of a dtype OpenCV supports. Boolean masks are
            converted to ``uint8`` (0/1) first, because OpenCV's
            ``erode``/``dilate`` do not accept ``bool`` arrays.
        kernel_size: Side length of the square structuring element. The
            default of 2 matches the previously hard-coded kernel.

    Returns:
        The opened mask. For boolean input the result is ``uint8`` 0/1.
    """
    if mask.dtype == np.bool_:
        mask = mask.astype(np.uint8)

    kernel = np.ones((kernel_size, kernel_size), np.uint8)

    # OpenCV's default anchor is the kernel centre (k // 2, k // 2). For an
    # even kernel this is off-centre. Because cv2.dilate does not reflect
    # the kernel, reusing that anchor for the dilation would shift the
    # result one pixel down/right. The reflected anchor keeps regions in
    # place. For odd kernels both anchors coincide.
    centre = kernel_size // 2
    reflected = kernel_size - 1 - centre
    eroded = cv2.erode(mask, kernel, anchor=(centre, centre), iterations=1)
    return cv2.dilate(eroded, kernel, anchor=(reflected, reflected), iterations=1)