Spaces:
Runtime error
Runtime error
yourusername
commited on
Commit
Β·
eabc299
1
Parent(s):
3e95b67
:sparkles: return foreground only
Browse files
app.py
CHANGED
@@ -33,7 +33,7 @@ def get_scale_factor(im_h, im_w, ref_size=512):
|
|
33 |
MODEL_PATH = hf_hub_download('nateraw/background-remover-files', 'modnet.onnx', repo_type='dataset')
|
34 |
|
35 |
|
36 |
-
def main(image_path):
|
37 |
|
38 |
# read image
|
39 |
im = cv2.imread(image_path)
|
@@ -85,9 +85,17 @@ def main(image_path):
|
|
85 |
image = np.repeat(image, 3, axis=2)
|
86 |
elif image.shape[2] == 4:
|
87 |
image = image[:, :, 0:3]
|
88 |
-
|
89 |
-
|
90 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
|
92 |
|
93 |
title = "MODNet Background Remover"
|
@@ -99,9 +107,12 @@ image = Image.open(requests.get(url, stream=True).raw)
|
|
99 |
image.save('twitter_profile_pic.jpeg')
|
100 |
interface = gr.Interface(
|
101 |
fn=main,
|
102 |
-
inputs=
|
|
|
|
|
|
|
103 |
outputs='image',
|
104 |
-
examples=[['twitter_profile_pic.jpeg']],
|
105 |
title=title,
|
106 |
description=description,
|
107 |
article=article,
|
|
|
33 |
MODEL_PATH = hf_hub_download('nateraw/background-remover-files', 'modnet.onnx', repo_type='dataset')
|
34 |
|
35 |
|
36 |
+
def main(image_path, threshold):
|
37 |
|
38 |
# read image
|
39 |
im = cv2.imread(image_path)
|
|
|
85 |
image = np.repeat(image, 3, axis=2)
|
86 |
elif image.shape[2] == 4:
|
87 |
image = image[:, :, 0:3]
|
88 |
+
|
89 |
+
b, g, r = cv2.split(image)
|
90 |
+
|
91 |
+
mask = np.asarray(matte)
|
92 |
+
a = np.ones(mask.shape, dtype='uint8') * 255
|
93 |
+
alpha_im = cv2.merge([b, g, r, a], 4)
|
94 |
+
bg = np.zeros(alpha_im.shape)
|
95 |
+
new_mask = np.stack([mask, mask, mask, mask], axis=2)
|
96 |
+
foreground = np.where(new_mask > threshold, alpha_im, bg).astype(np.uint8)
|
97 |
+
|
98 |
+
return Image.fromarray(foreground)
|
99 |
|
100 |
|
101 |
title = "MODNet Background Remover"
|
|
|
107 |
image.save('twitter_profile_pic.jpeg')
|
108 |
interface = gr.Interface(
|
109 |
fn=main,
|
110 |
+
inputs=[
|
111 |
+
gr.inputs.Image(type='filepath'),
|
112 |
+
gr.inputs.Slider(minimum=0, maximum=250, default=100, step=5, label='Mask Cutoff Threshold'),
|
113 |
+
],
|
114 |
outputs='image',
|
115 |
+
examples=[['twitter_profile_pic.jpeg', 120]],
|
116 |
title=title,
|
117 |
description=description,
|
118 |
article=article,
|