|
import io |
|
from enum import Enum |
|
from typing import List, Optional, Union |
|
|
|
import numpy as np |
|
from cv2 import ( |
|
BORDER_DEFAULT, |
|
MORPH_ELLIPSE, |
|
MORPH_OPEN, |
|
GaussianBlur, |
|
getStructuringElement, |
|
morphologyEx, |
|
) |
|
from PIL import Image |
|
from PIL.Image import Image as PILImage |
|
from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf |
|
from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml |
|
from pymatting.util.util import stack_images |
|
from scipy.ndimage import binary_erosion |
|
|
|
from .session_base import BaseSession |
|
from .session_factory import new_session |
|
|
|
# 3x3 elliptical structuring element; post_process() passes it to
# morphologyEx() for its MORPH_OPEN step.
kernel = getStructuringElement(MORPH_ELLIPSE, (3, 3))
|
|
|
|
|
class ReturnType(Enum):
    """Kind of object ``remove`` returns, chosen from the kind it was given."""

    # Encoded image bytes (input was ``bytes``).
    BYTES = 0
    # PIL image (input was a PIL image).
    PILLOW = 1
    # NumPy array (input was an ``np.ndarray``).
    NDARRAY = 2
|
|
|
|
|
def alpha_matting_cutout(
    img: PILImage,
    mask: PILImage,
    foreground_threshold: int,
    background_threshold: int,
    erode_structure_size: int,
) -> PILImage:
    """Cut the foreground out of ``img`` using alpha matting guided by ``mask``.

    A trimap is derived from ``mask``: pixels above ``foreground_threshold``
    are marked foreground (255), pixels below ``background_threshold`` are
    marked background (0), and everything else is left unknown (128). Both
    regions are eroded before being written into the trimap. PyMatting's
    ``estimate_alpha_cf`` and ``estimate_foreground_ml`` then compute the
    alpha matte and the foreground colours.

    Args:
        img: Source image. Any mode other than ``"RGB"`` is converted to
            ``"RGB"`` first so the matting code always sees three RGB
            channels.
        mask: Mask image whose pixel values are compared against the two
            thresholds.
        foreground_threshold: Mask values strictly above this are foreground.
        background_threshold: Mask values strictly below this are background.
        erode_structure_size: Side length of the square structuring element
            used for erosion. When not positive, SciPy's default structuring
            element is used instead.

    Returns:
        A PIL image built from the estimated foreground stacked with the
        estimated alpha channel, as ``uint8`` data.
    """
    # Normalise *every* non-RGB mode, not just RGBA/CMYK: modes such as
    # "LA", "YCbCr" or "HSV" would otherwise reach the matting code with the
    # wrong channel count or colour space and yield a mis-coloured result.
    if img.mode != "RGB":
        img = img.convert("RGB")

    img_array = np.asarray(img)
    mask_array = np.asarray(mask)

    is_foreground = mask_array > foreground_threshold
    is_background = mask_array < background_threshold

    structure = None
    if erode_structure_size > 0:
        structure = np.ones(
            (erode_structure_size, erode_structure_size), dtype=np.uint8
        )

    # Shrink both confident regions so the band around the mask boundary is
    # left as "unknown" for the matting step to resolve.
    is_foreground = binary_erosion(is_foreground, structure=structure)
    # border_value=1 makes pixels beyond the image edge count as set while
    # eroding the background region.
    is_background = binary_erosion(is_background, structure=structure, border_value=1)

    trimap = np.full(mask_array.shape, dtype=np.uint8, fill_value=128)
    trimap[is_foreground] = 255
    trimap[is_background] = 0

    # The matting functions are fed values scaled to the [0, 1] range.
    img_normalized = img_array / 255.0
    trimap_normalized = trimap / 255.0

    alpha = estimate_alpha_cf(img_normalized, trimap_normalized)
    foreground = estimate_foreground_ml(img_normalized, alpha)
    cutout = stack_images(foreground, alpha)

    # Back to 8-bit, clamping anything that drifted outside [0, 255].
    cutout = np.clip(cutout * 255, 0, 255).astype(np.uint8)
    return Image.fromarray(cutout)
|
|
|
|
|
def naive_cutout(img: PILImage, mask: PILImage) -> PILImage:
    """Composite ``img`` over a blank RGBA canvas using ``mask``.

    Args:
        img: Image to cut out.
        mask: Mask handed to ``Image.composite`` to choose between ``img``
            and the blank canvas.

    Returns:
        The composited image.
    """
    # Canvas of the same size as the input, filled with colour value 0.
    blank_canvas = Image.new("RGBA", img.size, 0)
    return Image.composite(img, blank_canvas, mask)
|
|
|
|
|
def get_concat_v_multi(imgs: List[PILImage]) -> PILImage:
    """Stack several images vertically, top to bottom in list order.

    Args:
        imgs: Non-empty sequence of images. It is only read, never modified.

    Returns:
        The first image itself when ``imgs`` has a single element; otherwise
        the result of folding ``get_concat_v`` over the images in order.

    Raises:
        IndexError: If ``imgs`` is empty.
    """
    # Index and slice instead of pop(0) so the caller's list stays intact.
    pivot = imgs[0]
    for im in imgs[1:]:
        pivot = get_concat_v(pivot, im)
    return pivot
|
|
|
|
|
def get_concat_v(img1: PILImage, img2: PILImage) -> PILImage:
    """Place ``img2`` directly below ``img1`` on a new RGBA canvas.

    The canvas takes its width from ``img1`` only and its height from the
    sum of both heights.

    Args:
        img1: Image pasted at the top-left corner.
        img2: Image pasted at the left edge, starting at ``img1``'s height.

    Returns:
        The new combined RGBA image.
    """
    top_height = img1.height
    canvas_size = (img1.width, top_height + img2.height)

    canvas = Image.new("RGBA", canvas_size)
    canvas.paste(img1, (0, 0))
    canvas.paste(img2, (0, top_height))
    return canvas
|
|
|
|
|
def post_process(mask: np.ndarray) -> np.ndarray:
    """
    Smooth the boundary of a mask using morphological operations.

    The mask is opened with the module-level ``kernel``, blurred with a 5x5
    Gaussian (sigma 2 on both axes), and then thresholded back to two levels:
    values below 127 become 0 and all other values become 255.

    Research based on paper: https://www.sciencedirect.com/science/article/pii/S2352914821000757

    args:
        mask: Binary Numpy Mask

    returns:
        ``uint8`` array holding only the values 0 and 255.
    """
    opened = morphologyEx(mask, MORPH_OPEN, kernel)
    blurred = GaussianBlur(
        opened, (5, 5), sigmaX=2, sigmaY=2, borderType=BORDER_DEFAULT
    )
    # Re-binarise: the blur introduces intermediate grey levels.
    binarised = np.where(blurred < 127, 0, 255)
    return binarised.astype(np.uint8)
|
|
|
|
|
def remove(
    data: Union[bytes, PILImage, np.ndarray],
    alpha_matting: bool = False,
    alpha_matting_foreground_threshold: int = 240,
    alpha_matting_background_threshold: int = 10,
    alpha_matting_erode_size: int = 10,
    session: Optional[BaseSession] = None,
    only_mask: bool = False,
    post_process_mask: bool = False,
) -> Union[bytes, PILImage, np.ndarray]:
    """Produce a cutout (or mask) for every mask the session predicts.

    Args:
        data: Input image as encoded bytes, a PIL image, or a NumPy array.
        alpha_matting: Use ``alpha_matting_cutout`` instead of
            ``naive_cutout``. If it raises ``ValueError`` the naive cutout is
            used for that mask instead.
        alpha_matting_foreground_threshold: Forwarded to
            ``alpha_matting_cutout`` as the foreground threshold.
        alpha_matting_background_threshold: Forwarded to
            ``alpha_matting_cutout`` as the background threshold.
        alpha_matting_erode_size: Forwarded to ``alpha_matting_cutout`` as
            the erosion structure size.
        session: Session whose ``predict`` yields the masks. When ``None``,
            a new ``"u2net"`` session is created.
        only_mask: Return the mask(s) themselves rather than cutouts.
        post_process_mask: Run each mask through ``post_process`` first.

    Returns:
        The per-mask results stacked vertically, in the same kind of
        container as ``data``: a PIL image, a NumPy array, or PNG-encoded
        bytes. When the session yields no masks, the input image is returned
        in that form instead.

    Raises:
        ValueError: If ``data`` is not bytes, a PIL image, or a NumPy array.
    """
    # Decode the input and remember which kind of object to hand back.
    if isinstance(data, PILImage):
        return_type, img = ReturnType.PILLOW, data
    elif isinstance(data, bytes):
        return_type, img = ReturnType.BYTES, Image.open(io.BytesIO(data))
    elif isinstance(data, np.ndarray):
        return_type, img = ReturnType.NDARRAY, Image.fromarray(data)
    else:
        raise ValueError("Input type {} is not supported.".format(type(data)))

    if session is None:
        session = new_session("u2net")

    results = []
    for mask in session.predict(img):
        if post_process_mask:
            mask = Image.fromarray(post_process(np.array(mask)))

        if only_mask:
            results.append(mask)
            continue

        if not alpha_matting:
            results.append(naive_cutout(img, mask))
            continue

        try:
            piece = alpha_matting_cutout(
                img,
                mask,
                alpha_matting_foreground_threshold,
                alpha_matting_background_threshold,
                alpha_matting_erode_size,
            )
        except ValueError:
            # Best-effort fallback when matting rejects its input.
            piece = naive_cutout(img, mask)
        results.append(piece)

    # With no masks at all, the untouched input image is the result.
    output = get_concat_v_multi(results) if results else img

    if return_type == ReturnType.PILLOW:
        return output
    if return_type == ReturnType.NDARRAY:
        return np.asarray(output)

    # Bytes in, bytes out: encode the result as PNG.
    buffer = io.BytesIO()
    output.save(buffer, "PNG")
    return buffer.getvalue()
|
|