"""Background-removal helpers built on rembg, pymatting and Pillow."""

from typing import Tuple

import numpy as np
import PIL
from PIL import Image, ImageDraw, ImageOps
from PIL.Image import Image as PILImage
from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf
from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
from pymatting.util.util import stack_images
from rembg.bg import apply_background_color, naive_cutout, post_process
from scipy.ndimage import binary_erosion


def alpha_matting_cutout(img: PILImage, trimap: np.ndarray) -> PILImage:
    """Cut the foreground out of *img* using alpha matting guided by *trimap*.

    Args:
        img: Source image. RGBA and CMYK images are converted to RGB first.
        trimap: Array with values on a 0-255 scale (it is divided by 255
            before being handed to pymatting), e.g. the output of
            ``generate_trimap``.

    Returns:
        A new image built from the estimated foreground stacked with the
        estimated alpha channel.
    """
    if img.mode in ("RGBA", "CMYK"):
        img = img.convert("RGB")

    # pymatting is fed values scaled down from the 0-255 range.
    img_normalized = np.asarray(img) / 255.0
    trimap_normalized = trimap / 255.0

    alpha = estimate_alpha_cf(img_normalized, trimap_normalized)
    foreground = estimate_foreground_ml(img_normalized, alpha)
    cutout = stack_images(foreground, alpha)

    cutout = np.clip(cutout * 255, 0, 255).astype(np.uint8)
    return Image.fromarray(cutout)


def generate_trimap(
    mask: PILImage,
    foreground_threshold: int,
    background_threshold: int,
    erode_structure_size: int,
) -> np.ndarray:
    """Build a trimap (0 = background, 128 = unknown, 255 = foreground).

    Args:
        mask: Mask image; values above ``foreground_threshold`` count as
            foreground, values below ``background_threshold`` as background.
        foreground_threshold: Lower bound (exclusive) for foreground pixels.
        background_threshold: Upper bound (exclusive) for background pixels.
        erode_structure_size: Side length of the square erosion structure.
            When not positive, ``binary_erosion`` is still applied, using
            its default structuring element.

    Returns:
        A ``uint8`` array with the same shape as the mask.
    """
    mask = np.asarray(mask)

    is_foreground = mask > foreground_threshold
    is_background = mask < background_threshold

    structure = None
    if erode_structure_size > 0:
        structure = np.ones(
            (erode_structure_size, erode_structure_size), dtype=np.uint8
        )

    # Eroding both regions widens the "unknown" band between them.
    is_foreground = binary_erosion(is_foreground, structure=structure)
    is_background = binary_erosion(is_background, structure=structure, border_value=1)

    trimap = np.full(mask.shape, dtype=np.uint8, fill_value=128)
    trimap[is_foreground] = 255
    trimap[is_background] = 0
    return trimap


def get_background_dominant_color(img: PILImage, mask: PILImage) -> tuple:
    """Return an opaque ``(r, g, b, 255)`` colour summarising the background.

    The inverted mask is used as the alpha channel of a copy of *img*, and
    that copy is shrunk to a single pixel whose RGB value is returned.

    Args:
        img: Source image.
        mask: Foreground mask; it is inverted so the background is kept.
            NOTE(review): ``putalpha`` needs the mask to match the image
            size -- confirm callers guarantee this.

    Returns:
        ``(r, g, b, 255)``.
    """
    # convert() always returns a new image, and guarantees four channels
    # even for greyscale/palette input (copy() + putalpha() would yield "LA").
    negative_img = img.convert("RGBA")
    negative_img.putalpha(ImageOps.invert(mask))

    r, g, b, _ = negative_img.resize((1, 1)).getpixel((0, 0))
    return r, g, b, 255


def remove(
    session, img: PILImage, smoot: bool, matting: tuple, color
) -> Tuple[PILImage, PILImage]:
    """Remove the background of *img* and return the cutout with its mask.

    Args:
        session: Object whose ``predict(img)`` returns a sequence; the first
            element is used as the mask.
        img: Source image.
        smoot: When true, the mask is run through ``post_process`` first.
        matting: ``(foreground_threshold, background_threshold,
            erode_structure_size)``. If any value is positive, alpha matting
            is used; otherwise ``naive_cutout`` is used.
        color: ``True`` to fill the background with the colour computed by
            ``get_background_dominant_color``, a string accepted by
            ``parse_rgba`` for an explicit colour, anything else to leave
            the cutout unchanged.

    Returns:
        ``(cutout, mask)``. With alpha matting enabled the returned mask is
        the generated trimap.

    Raises:
        ValueError: Propagated unchanged from ``alpha_matting_cutout``.
    """
    mask = session.predict(img)[0]

    if smoot:
        mask = PIL.Image.fromarray(post_process(np.array(mask)))

    fg_t, bg_t, erode = matting
    if fg_t > 0 or bg_t > 0 or erode > 0:
        trimap = generate_trimap(mask, *matting)
        cutout = alpha_matting_cutout(img, trimap)
        mask = PIL.Image.fromarray(trimap)
    else:
        cutout = naive_cutout(img, mask)

    if color is True:
        cutout = apply_background_color(
            cutout, get_background_dominant_color(img, mask)
        )
    elif isinstance(color, str):
        cutout = apply_background_color(cutout, parse_rgba(color))

    return cutout, mask


def parse_rgba(color_str: str) -> Tuple[int, int, int, int]:
    """Parse a CSS-like ``rgba(r, g, b, a)`` string into a 0-255 tuple.

    The colour channels are truncated to integers; the alpha component is
    multiplied by 255 and truncated. A three-component ``rgb(r, g, b)``
    string is accepted as well and yields an alpha of 255.

    Args:
        color_str: Text such as ``"rgba(255, 0, 0, 0.5)"``.

    Returns:
        ``(r, g, b, a)`` as integers.

    Raises:
        ValueError: If the parentheses are missing, fewer than three
            components are given, or a component is not numeric.
    """
    _, opening, remainder = color_str.strip().partition("(")
    body, closing, _ = remainder.rpartition(")")
    if not opening or not closing:
        raise ValueError(f"Invalid color string: {color_str!r}")

    components = [part.strip() for part in body.split(",")]
    if len(components) < 3:
        raise ValueError(f"Invalid color string: {color_str!r}")

    r, g, b = (int(float(value)) for value in components[:3])
    # Alpha scaled to 0-255; opaque when the alpha component is omitted.
    a = int(float(components[3]) * 255) if len(components) > 3 else 255
    return r, g, b, a


def text_size(draw, text):
    """Return the right and bottom edges of *text*'s bounding box at (0, 0).

    These are used as the rendered width and height of the text.
    """
    _, _, width, height = draw.textbbox((0, 0), text=text)
    return width, height


def make_label(text, width=600, height=200, color="black") -> PILImage:
    """Create a ``width`` x ``height`` RGB image with *text* centred on it.

    Args:
        text: Text to draw (with ImageDraw's default font and fill).
        width: Image width in pixels.
        height: Image height in pixels.
        color: Background colour passed to ``Image.new``.

    Returns:
        The newly created image.
    """
    image = Image.new("RGB", (width, height), color)
    draw = ImageDraw.Draw(image)

    text_width, text_height = text_size(draw, text)
    # Centre on both axes; the text's own height must be subtracted too,
    # otherwise the text hangs below the vertical midpoint.
    position = ((width - text_width) / 2, (height - text_height) / 2)
    draw.text(position, text)
    return image