ghostsInTheMachine commited on
Commit
ab98f09
·
verified ·
1 Parent(s): a0c2c56

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -10
app.py CHANGED
@@ -5,7 +5,7 @@ import torch
5
  import gradio as gr
6
  import spaces
7
 
8
- from PIL import Image
9
  from transformers import AutoModelForImageSegmentation
10
  from torchvision import transforms
11
 
@@ -78,23 +78,38 @@ def remove_background(image):
78
  preds = birefnet(image_proc.to(device))[-1].sigmoid().cpu()
79
  pred = preds[0].squeeze()
80
 
81
- # Show Results
82
  pred_pil = transforms.ToPILImage()(pred)
83
  pred_pil = pred_pil.resize(original_size, Image.BICUBIC) # Resize mask to original size
84
- image_masked = refine_foreground(image_ori, pred_pil)
85
- image_masked.putalpha(pred_pil)
 
 
 
 
 
 
 
 
 
 
86
 
87
  torch.cuda.empty_cache()
88
 
89
- # Save mask as PNG
90
  mask_path = "mask.png"
91
  pred_pil.save(mask_path)
92
 
93
- # Save output as PNG
94
- output_path = "output.png"
95
- image_masked.save(output_path)
 
 
 
 
 
96
 
97
- return mask_path, output_path
98
 
99
  css = """
100
  body {
@@ -139,7 +154,9 @@ iface = gr.Interface(
139
  inputs=gr.Image(type="numpy"),
140
  outputs=[
141
  gr.Image(type="filepath", label="Mask"),
142
- gr.Image(type="filepath", label="Output")
 
 
143
  ],
144
  allow_flagging="never",
145
  css=css
 
5
  import gradio as gr
6
  import spaces
7
 
8
+ from PIL import Image, ImageOps
9
  from transformers import AutoModelForImageSegmentation
10
  from torchvision import transforms
11
 
 
78
  preds = birefnet(image_proc.to(device))[-1].sigmoid().cpu()
79
  pred = preds[0].squeeze()
80
 
81
+ # Process Results
82
  pred_pil = transforms.ToPILImage()(pred)
83
  pred_pil = pred_pil.resize(original_size, Image.BICUBIC) # Resize mask to original size
84
+
85
+ # Create reverse mask
86
+ reverse_mask = Image.new('L', original_size)
87
+ reverse_mask.paste(ImageOps.invert(pred_pil))
88
+
89
+ # Create foreground image (object with transparent background)
90
+ foreground = image_ori.copy()
91
+ foreground.putalpha(pred_pil)
92
+
93
+ # Create background image
94
+ background = image_ori.copy()
95
+ background.putalpha(reverse_mask)
96
 
97
  torch.cuda.empty_cache()
98
 
99
+ # Save all images
100
  mask_path = "mask.png"
101
  pred_pil.save(mask_path)
102
 
103
+ reverse_mask_path = "reverse_mask.png"
104
+ reverse_mask.save(reverse_mask_path)
105
+
106
+ foreground_path = "foreground.png"
107
+ foreground.save(foreground_path)
108
+
109
+ background_path = "background.png"
110
+ background.save(background_path)
111
 
112
+ return mask_path, reverse_mask_path, foreground_path, background_path
113
 
114
  css = """
115
  body {
 
154
  inputs=gr.Image(type="numpy"),
155
  outputs=[
156
  gr.Image(type="filepath", label="Mask"),
157
+ gr.Image(type="filepath", label="Reverse Mask"),
158
+ gr.Image(type="filepath", label="Foreground"),
159
+ gr.Image(type="filepath", label="Background")
160
  ],
161
  allow_flagging="never",
162
  css=css