Spaces:
Running
Running
Upload app.py
Browse files
app.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from PIL import Image
|
3 |
+
import numpy as np
|
4 |
+
import cv2
|
5 |
+
import os
|
6 |
+
import tensorflow as tf
|
7 |
+
|
8 |
+
if tf.__version__ >= '2.0':
|
9 |
+
tf = tf.compat.v1
|
10 |
+
|
11 |
+
class ImageMattingPipeline:
|
12 |
+
def __init__(self, model_dir: str, input_name: str = 'input_image:0', output_name: str = 'output_png:0'):
|
13 |
+
model_path = os.path.join(model_dir, 'tf_graph.pb')
|
14 |
+
if not os.path.exists(model_path):
|
15 |
+
raise FileNotFoundError("Model file not found at {}".format(model_path))
|
16 |
+
config = tf.ConfigProto(allow_soft_placement=True)
|
17 |
+
config.gpu_options.allow_growth = True
|
18 |
+
self.graph = tf.Graph()
|
19 |
+
with self.graph.as_default():
|
20 |
+
self._session = tf.Session(config=config)
|
21 |
+
with tf.gfile.FastGFile(model_path, 'rb') as f:
|
22 |
+
graph_def = tf.GraphDef()
|
23 |
+
graph_def.ParseFromString(f.read())
|
24 |
+
tf.import_graph_def(graph_def, name='')
|
25 |
+
self.output = self._session.graph.get_tensor_by_name(output_name)
|
26 |
+
self.input_name = input_name
|
27 |
+
|
28 |
+
def preprocess(self, input_image):
|
29 |
+
img = np.array(input_image)
|
30 |
+
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
|
31 |
+
img = img.astype(float)
|
32 |
+
return {'img': img}
|
33 |
+
|
34 |
+
def forward(self, input, output_mask=False, alpha_threshold=128):
|
35 |
+
with self.graph.as_default(), self._session.as_default():
|
36 |
+
feed_dict = {self.input_name: input['img']}
|
37 |
+
output_img = self._session.run(self.output, feed_dict=feed_dict)
|
38 |
+
result = {'output_img': output_img}
|
39 |
+
if output_mask:
|
40 |
+
alpha_channel = output_img[:, :, 3]
|
41 |
+
mask = np.zeros(alpha_channel.shape, dtype=np.uint8)
|
42 |
+
mask[alpha_channel >= alpha_threshold] = 255
|
43 |
+
output_img[mask == 0, 3] = 0
|
44 |
+
result['mask'] = mask
|
45 |
+
return result
|
46 |
+
|
47 |
+
def apply_filters(mask: np.array, closing_kernel: tuple = (5, 5), opening_kernel: tuple = (5, 5),
|
48 |
+
blur_kernel: tuple = (3, 3), bilateral_params: tuple = (9, 75, 75),
|
49 |
+
min_area: int = 2000) -> np.array:
|
50 |
+
mask = mask.astype(np.uint8)
|
51 |
+
closing_element = np.ones(closing_kernel, np.uint8)
|
52 |
+
opening_element = np.ones(opening_kernel, np.uint8)
|
53 |
+
closed_mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, closing_element)
|
54 |
+
opened_mask = cv2.morphologyEx(closed_mask, cv2.MORPH_OPEN, opening_element)
|
55 |
+
smoothed_mask = cv2.GaussianBlur(opened_mask, blur_kernel, 0)
|
56 |
+
edge_smoothed_mask = cv2.bilateralFilter(smoothed_mask, *bilateral_params)
|
57 |
+
num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(edge_smoothed_mask, connectivity=8)
|
58 |
+
large_component_mask = np.zeros_like(edge_smoothed_mask)
|
59 |
+
for i in range(1, num_labels):
|
60 |
+
if stats[i, cv2.CC_STAT_AREA] >= min_area:
|
61 |
+
large_component_mask[labels == i] = 255
|
62 |
+
return large_component_mask
|
63 |
+
|
64 |
+
def matting_interface(input_image, apply_morphology):
|
65 |
+
input_image = np.array(input_image)
|
66 |
+
input_image = input_image[:, :, ::-1]
|
67 |
+
|
68 |
+
pipeline = ImageMattingPipeline(model_dir='cv_unet_universal-matting')
|
69 |
+
preprocessed = pipeline.preprocess(input_image)
|
70 |
+
result = pipeline.forward(preprocessed, output_mask=True)
|
71 |
+
|
72 |
+
if apply_morphology:
|
73 |
+
mask = apply_filters(result['mask'])
|
74 |
+
else:
|
75 |
+
mask = result.get('mask', None)
|
76 |
+
|
77 |
+
output_img_pil = Image.fromarray(result['output_img'].astype(np.uint8))
|
78 |
+
mask_pil = Image.fromarray(mask) if mask is not None else None
|
79 |
+
|
80 |
+
return output_img_pil, mask_pil
|
81 |
+
|
82 |
+
iface = gr.Interface(
|
83 |
+
fn=matting_interface,
|
84 |
+
inputs=[
|
85 |
+
gr.components.Image(type="pil", image_mode="RGB"),
|
86 |
+
gr.components.Checkbox(label="Apply Morphological Processing for Mask")
|
87 |
+
],
|
88 |
+
outputs=[
|
89 |
+
gr.components.Image(type="pil", label="Matting Result"),
|
90 |
+
gr.components.Image(type="pil", label="Mask"),
|
91 |
+
],
|
92 |
+
title="Image Matting with Morphological Processing Option",
|
93 |
+
description="Upload an image to get the matting result and mask. "
|
94 |
+
"Use the checkbox to enable or disable morphological processing on the mask."
|
95 |
+
)
|
96 |
+
|
97 |
+
iface.launch()
|