haor commited on
Commit
850927f
1 Parent(s): 29cd496

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -0
app.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ import numpy as np
4
+ import cv2
5
+ import os
6
+ import tensorflow as tf
7
+
8
+ if tf.__version__ >= '2.0':
9
+ tf = tf.compat.v1
10
+
11
+ class ImageMattingPipeline:
12
+ def __init__(self, model_dir: str, input_name: str = 'input_image:0', output_name: str = 'output_png:0'):
13
+ model_path = os.path.join(model_dir, 'tf_graph.pb')
14
+ if not os.path.exists(model_path):
15
+ raise FileNotFoundError("Model file not found at {}".format(model_path))
16
+ config = tf.ConfigProto(allow_soft_placement=True)
17
+ config.gpu_options.allow_growth = True
18
+ self.graph = tf.Graph()
19
+ with self.graph.as_default():
20
+ self._session = tf.Session(config=config)
21
+ with tf.gfile.FastGFile(model_path, 'rb') as f:
22
+ graph_def = tf.GraphDef()
23
+ graph_def.ParseFromString(f.read())
24
+ tf.import_graph_def(graph_def, name='')
25
+ self.output = self._session.graph.get_tensor_by_name(output_name)
26
+ self.input_name = input_name
27
+
28
+ def preprocess(self, input_image):
29
+ img = np.array(input_image)
30
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
31
+ img = img.astype(float)
32
+ return {'img': img}
33
+
34
+ def forward(self, input, output_mask=False, alpha_threshold=128):
35
+ with self.graph.as_default(), self._session.as_default():
36
+ feed_dict = {self.input_name: input['img']}
37
+ output_img = self._session.run(self.output, feed_dict=feed_dict)
38
+ result = {'output_img': output_img}
39
+ if output_mask:
40
+ alpha_channel = output_img[:, :, 3]
41
+ mask = np.zeros(alpha_channel.shape, dtype=np.uint8)
42
+ mask[alpha_channel >= alpha_threshold] = 255
43
+ output_img[mask == 0, 3] = 0
44
+ result['mask'] = mask
45
+ return result
46
+
47
+ def apply_filters(mask: np.array, closing_kernel: tuple = (5, 5), opening_kernel: tuple = (5, 5),
48
+ blur_kernel: tuple = (3, 3), bilateral_params: tuple = (9, 75, 75),
49
+ min_area: int = 2000) -> np.array:
50
+ mask = mask.astype(np.uint8)
51
+ closing_element = np.ones(closing_kernel, np.uint8)
52
+ opening_element = np.ones(opening_kernel, np.uint8)
53
+ closed_mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, closing_element)
54
+ opened_mask = cv2.morphologyEx(closed_mask, cv2.MORPH_OPEN, opening_element)
55
+ smoothed_mask = cv2.GaussianBlur(opened_mask, blur_kernel, 0)
56
+ edge_smoothed_mask = cv2.bilateralFilter(smoothed_mask, *bilateral_params)
57
+ num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(edge_smoothed_mask, connectivity=8)
58
+ large_component_mask = np.zeros_like(edge_smoothed_mask)
59
+ for i in range(1, num_labels):
60
+ if stats[i, cv2.CC_STAT_AREA] >= min_area:
61
+ large_component_mask[labels == i] = 255
62
+ return large_component_mask
63
+
64
+ def matting_interface(input_image, apply_morphology):
65
+ input_image = np.array(input_image)
66
+ input_image = input_image[:, :, ::-1]
67
+
68
+ pipeline = ImageMattingPipeline(model_dir='cv_unet_universal-matting')
69
+ preprocessed = pipeline.preprocess(input_image)
70
+ result = pipeline.forward(preprocessed, output_mask=True)
71
+
72
+ if apply_morphology:
73
+ mask = apply_filters(result['mask'])
74
+ else:
75
+ mask = result.get('mask', None)
76
+
77
+ output_img_pil = Image.fromarray(result['output_img'].astype(np.uint8))
78
+ mask_pil = Image.fromarray(mask) if mask is not None else None
79
+
80
+ return output_img_pil, mask_pil
81
+
82
+ iface = gr.Interface(
83
+ fn=matting_interface,
84
+ inputs=[
85
+ gr.components.Image(type="pil", image_mode="RGB"),
86
+ gr.components.Checkbox(label="Apply Morphological Processing for Mask")
87
+ ],
88
+ outputs=[
89
+ gr.components.Image(type="pil", label="Matting Result"),
90
+ gr.components.Image(type="pil", label="Mask"),
91
+ ],
92
+ title="Image Matting with Morphological Processing Option",
93
+ description="Upload an image to get the matting result and mask. "
94
+ "Use the checkbox to enable or disable morphological processing on the mask."
95
+ )
96
+
97
+ iface.launch()