yourusername commited on
Commit
eabc299
Β·
1 Parent(s): 3e95b67

:sparkles: return foreground only

Browse files
Files changed (1) hide show
  1. app.py +17 -6
app.py CHANGED
@@ -33,7 +33,7 @@ def get_scale_factor(im_h, im_w, ref_size=512):
33
  MODEL_PATH = hf_hub_download('nateraw/background-remover-files', 'modnet.onnx', repo_type='dataset')
34
 
35
 
36
- def main(image_path):
37
 
38
  # read image
39
  im = cv2.imread(image_path)
@@ -85,9 +85,17 @@ def main(image_path):
85
  image = np.repeat(image, 3, axis=2)
86
  elif image.shape[2] == 4:
87
  image = image[:, :, 0:3]
88
- matte = np.repeat(np.asarray(matte)[:, :, None], 3, axis=2) / 255
89
- foreground = image * matte + np.full(image.shape, 255) * (1 - matte)
90
- return Image.fromarray(foreground.astype(np.uint8))
 
 
 
 
 
 
 
 
91
 
92
 
93
  title = "MODNet Background Remover"
@@ -99,9 +107,12 @@ image = Image.open(requests.get(url, stream=True).raw)
99
  image.save('twitter_profile_pic.jpeg')
100
  interface = gr.Interface(
101
  fn=main,
102
- inputs=gr.inputs.Image(type='filepath'),
 
 
 
103
  outputs='image',
104
- examples=[['twitter_profile_pic.jpeg']],
105
  title=title,
106
  description=description,
107
  article=article,
 
33
  MODEL_PATH = hf_hub_download('nateraw/background-remover-files', 'modnet.onnx', repo_type='dataset')
34
 
35
 
36
+ def main(image_path, threshold):
37
 
38
  # read image
39
  im = cv2.imread(image_path)
 
85
  image = np.repeat(image, 3, axis=2)
86
  elif image.shape[2] == 4:
87
  image = image[:, :, 0:3]
88
+
89
+ b, g, r = cv2.split(image)
90
+
91
+ mask = np.asarray(matte)
92
+ a = np.ones(mask.shape, dtype='uint8') * 255
93
+ alpha_im = cv2.merge([b, g, r, a], 4)
94
+ bg = np.zeros(alpha_im.shape)
95
+ new_mask = np.stack([mask, mask, mask, mask], axis=2)
96
+ foreground = np.where(new_mask > threshold, alpha_im, bg).astype(np.uint8)
97
+
98
+ return Image.fromarray(foreground)
99
 
100
 
101
  title = "MODNet Background Remover"
 
107
  image.save('twitter_profile_pic.jpeg')
108
  interface = gr.Interface(
109
  fn=main,
110
+ inputs=[
111
+ gr.inputs.Image(type='filepath'),
112
+ gr.inputs.Slider(minimum=0, maximum=250, default=100, step=5, label='Mask Cutoff Threshold'),
113
+ ],
114
  outputs='image',
115
+ examples=[['twitter_profile_pic.jpeg', 120]],
116
  title=title,
117
  description=description,
118
  article=article,