File size: 592 Bytes
d9697ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import cv2
import torch
import numpy as np
import numpy.typing as npt

from torch import Tensor
from kornia.morphology import erosion, dilation


def clean_mask_torch(mask: Tensor) -> Tensor:
    """Apply a 2x2 morphological opening (erosion, then dilation) to a mask.

    Args:
        mask: Mask tensor. A 2-D tensor is given two leading singleton
            dimensions before processing; a boolean tensor is converted to
            an integer tensor first.

    Returns:
        The result of kornia's ``dilation`` applied to the eroded mask.
        NOTE(review): a 2-D input is not squeezed back, so the result
        presumably keeps the two added leading dimensions -- confirm
        against callers.
    """
    if mask.dim() == 2:
        # Add two leading singleton dims (presumably batch and channel --
        # confirm against kornia's expected layout).
        mask = mask.unsqueeze(0).unsqueeze(0)
    if mask.dtype == torch.bool:
        # Convert boolean masks to a numeric dtype before the morphology ops.
        mask = mask.to(torch.int32)

    # All-ones 2x2 structuring element, created on the mask's device.
    structuring_element = torch.ones(2, 2, device=mask.device)
    eroded = erosion(mask, structuring_element)
    return dilation(eroded, structuring_element)


def clean_mask_np(mask: npt.NDArray) -> npt.NDArray:
    """Apply a 2x2 morphological opening (erode, then dilate) to a mask.

    Args:
        mask: Mask array to pass to ``cv2.erode``. A boolean array is
            converted to ``uint8`` first, matching how ``clean_mask_torch``
            converts boolean tensors to an integer dtype.

    Returns:
        The array returned by ``cv2.dilate`` after one erosion and one
        dilation pass with an all-ones 2x2 kernel. For a boolean input the
        result is ``uint8`` (it is not cast back to ``bool``).
    """
    if mask.dtype == np.bool_:
        # OpenCV's Python bindings do not accept bool arrays; use uint8 0/1.
        mask = mask.astype(np.uint8)

    kernel = np.ones((2, 2), np.uint8)
    eroded = cv2.erode(mask, kernel, iterations=1)
    return cv2.dilate(eroded, kernel, iterations=1)