matting_mask / app.py
haor's picture
Upload app.py
850927f verified
raw
history blame
4.13 kB
import gradio as gr
from PIL import Image
import numpy as np
import cv2
import os
import tensorflow as tf
# Under TensorFlow 2.x and later, rebind ``tf`` to the v1 compatibility API so
# the graph/session calls below keep working. The major version is compared
# numerically: comparing the raw version strings is lexicographic and would
# treat e.g. '10.0' as older than '2.0'.
if int(tf.__version__.split('.')[0]) >= 2:
    tf = tf.compat.v1
class ImageMattingPipeline:
    """Run a frozen TensorFlow matting graph on single images.

    The graph is read from ``<model_dir>/tf_graph.pb`` and executed in a
    ``tf.Graph``/``tf.Session`` pair owned by the instance. Call
    :meth:`close`, or use the instance as a context manager, to release the
    session once the pipeline is no longer needed.
    """

    def __init__(self, model_dir: str, input_name: str = 'input_image:0', output_name: str = 'output_png:0'):
        """Load the frozen graph and open a session on it.

        Args:
            model_dir: Directory containing ``tf_graph.pb``.
            input_name: Name of the tensor the image is fed into.
            output_name: Name of the tensor fetched as the matting result.

        Raises:
            FileNotFoundError: If ``tf_graph.pb`` is missing from ``model_dir``.
        """
        model_path = os.path.join(model_dir, 'tf_graph.pb')
        if not os.path.exists(model_path):
            raise FileNotFoundError("Model file not found at {}".format(model_path))

        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True

        self._session = None
        self.graph = tf.Graph()
        with self.graph.as_default():
            # GFile is the supported replacement for the deprecated FastGFile.
            with tf.gfile.GFile(model_path, 'rb') as f:
                graph_def = tf.GraphDef()
                graph_def.ParseFromString(f.read())
            tf.import_graph_def(graph_def, name='')
        self.output = self.graph.get_tensor_by_name(output_name)
        self.input_name = input_name
        # Open the session last: if parsing the graph or looking up the output
        # tensor fails, there is no half-initialised session left to leak.
        self._session = tf.Session(graph=self.graph, config=config)

    def close(self):
        """Release the TensorFlow session; safe to call more than once."""
        session = getattr(self, '_session', None)
        if session is not None:
            session.close()
            self._session = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def preprocess(self, input_image):
        """Convert an image into the feed structure used by :meth:`forward`.

        The image is turned into an array, its first and third channels are
        swapped (``cv2.COLOR_RGB2BGR``) and the result is cast to float.

        Returns:
            A dict with the converted array under the key ``'img'``.
        """
        img = np.array(input_image)
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        img = img.astype(float)
        return {'img': img}

    def forward(self, input, output_mask=False, alpha_threshold=128):
        """Run the graph on a preprocessed image.

        Args:
            input: Dict produced by :meth:`preprocess`; only ``input['img']``
                is read.
            output_mask: When true, also derive a binary mask from the alpha
                channel and zero the alpha of every pixel outside the mask.
            alpha_threshold: Alpha values greater than or equal to this become
                255 in the mask; all others become 0.

        Returns:
            A dict with the fetched array under ``'output_img'`` and, when
            ``output_mask`` is true, the uint8 mask under ``'mask'``.
        """
        with self.graph.as_default(), self._session.as_default():
            feed_dict = {self.input_name: input['img']}
            output_img = self._session.run(self.output, feed_dict=feed_dict)
        result = {'output_img': output_img}
        if output_mask:
            # Assumes the fetched array is H x W x 4 with alpha at index 3.
            alpha_channel = output_img[:, :, 3]
            mask = np.zeros(alpha_channel.shape, dtype=np.uint8)
            mask[alpha_channel >= alpha_threshold] = 255
            # Make below-threshold pixels fully transparent in the output too.
            output_img[mask == 0, 3] = 0
            result['mask'] = mask
        return result
def apply_filters(mask: np.ndarray, closing_kernel: tuple = (5, 5), opening_kernel: tuple = (5, 5),
                  blur_kernel: tuple = (3, 3), bilateral_params: tuple = (9, 75, 75),
                  min_area: int = 2000) -> np.ndarray:
    """Clean up a mask and keep only its large connected regions.

    The mask goes through a morphological closing, a morphological opening,
    a Gaussian blur and a bilateral filter. Connected components of the
    filtered image are then labelled, and only those whose area is at least
    ``min_area`` pixels are kept.

    Args:
        mask: Mask array; it is cast to uint8 before filtering.
        closing_kernel: Shape of the all-ones structuring element for closing.
        opening_kernel: Shape of the all-ones structuring element for opening.
        blur_kernel: Kernel size passed to ``cv2.GaussianBlur``.
        bilateral_params: Positional arguments passed to
            ``cv2.bilateralFilter`` after the image.
        min_area: Smallest component area, in pixels, that is retained.

    Returns:
        A uint8 array shaped like the filtered mask, 255 inside the retained
        components and 0 elsewhere.
    """
    mask = mask.astype(np.uint8)
    closing_element = np.ones(closing_kernel, np.uint8)
    opening_element = np.ones(opening_kernel, np.uint8)

    closed_mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, closing_element)
    opened_mask = cv2.morphologyEx(closed_mask, cv2.MORPH_OPEN, opening_element)
    smoothed_mask = cv2.GaussianBlur(opened_mask, blur_kernel, 0)
    edge_smoothed_mask = cv2.bilateralFilter(smoothed_mask, *bilateral_params)

    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(edge_smoothed_mask, connectivity=8)

    # One boolean per label, used as a lookup table indexed by the label
    # image. This gives the same result as painting each large component in
    # turn, without rescanning the whole image once per component.
    keep = stats[:, cv2.CC_STAT_AREA] >= min_area
    keep[0] = False  # label 0 is the background and is never retained
    return np.where(keep[labels], np.uint8(255), np.uint8(0))
# Pipelines already loaded, keyed by model directory. Loading the frozen graph
# and opening a session is expensive, so it is done once per directory rather
# than once per request.
_PIPELINE_CACHE = {}


def _get_pipeline(model_dir):
    """Return the shared pipeline for *model_dir*, loading it on first use."""
    pipeline = _PIPELINE_CACHE.get(model_dir)
    if pipeline is None:
        pipeline = ImageMattingPipeline(model_dir=model_dir)
        _PIPELINE_CACHE[model_dir] = pipeline
    return pipeline


def matting_interface(input_image, apply_morphology):
    """Gradio callback: produce the matting result and its mask.

    Args:
        input_image: The uploaded image, or ``None`` when nothing was
            uploaded.
        apply_morphology: When true, the mask is post-processed with
            ``apply_filters`` before being returned.

    Returns:
        A ``(matting_result, mask)`` pair of PIL images, or ``(None, None)``
        when no image was supplied.
    """
    if input_image is None:
        # Nothing to process; previously this crashed while indexing the array.
        return None, None

    input_image = np.array(input_image)
    # NOTE(review): the channel axis is reversed here and preprocess() then
    # applies COLOR_RGB2BGR, so the two swaps cancel out. Confirm which channel
    # order the model actually expects.
    input_image = input_image[:, :, ::-1]

    pipeline = _get_pipeline('cv_unet_universal-matting')
    preprocessed = pipeline.preprocess(input_image)
    result = pipeline.forward(preprocessed, output_mask=True)

    if apply_morphology:
        mask = apply_filters(result['mask'])
    else:
        mask = result.get('mask', None)

    output_img_pil = Image.fromarray(result['output_img'].astype(np.uint8))
    mask_pil = Image.fromarray(mask) if mask is not None else None
    return output_img_pil, mask_pil
# Widgets shown to the user: the image to matte and a toggle that decides
# whether the mask is post-processed.
_input_components = [
    gr.components.Image(type="pil", image_mode="RGB"),
    gr.components.Checkbox(label="Apply Morphological Processing for Mask"),
]

# Widgets that display the two values returned by ``matting_interface``.
_output_components = [
    gr.components.Image(type="pil", label="Matting Result"),
    gr.components.Image(type="pil", label="Mask"),
]

_description = (
    "Upload an image to get the matting result and mask. "
    "Use the checkbox to enable or disable morphological processing on the mask."
)

iface = gr.Interface(
    fn=matting_interface,
    inputs=_input_components,
    outputs=_output_components,
    title="Image Matting with Morphological Processing Option",
    description=_description,
)

# Start the web UI as soon as the script runs.
iface.launch()