Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -5,7 +5,7 @@ import torch
|
|
5 |
import gradio as gr
|
6 |
import spaces
|
7 |
|
8 |
-
from PIL import Image
|
9 |
from transformers import AutoModelForImageSegmentation
|
10 |
from torchvision import transforms
|
11 |
|
@@ -78,23 +78,38 @@ def remove_background(image):
|
|
78 |
preds = birefnet(image_proc.to(device))[-1].sigmoid().cpu()
|
79 |
pred = preds[0].squeeze()
|
80 |
|
81 |
-
#
|
82 |
pred_pil = transforms.ToPILImage()(pred)
|
83 |
pred_pil = pred_pil.resize(original_size, Image.BICUBIC) # Resize mask to original size
|
84 |
-
|
85 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
|
87 |
torch.cuda.empty_cache()
|
88 |
|
89 |
-
# Save
|
90 |
mask_path = "mask.png"
|
91 |
pred_pil.save(mask_path)
|
92 |
|
93 |
-
|
94 |
-
|
95 |
-
|
|
|
|
|
|
|
|
|
|
|
96 |
|
97 |
-
return mask_path,
|
98 |
|
99 |
css = """
|
100 |
body {
|
@@ -139,7 +154,9 @@ iface = gr.Interface(
|
|
139 |
inputs=gr.Image(type="numpy"),
|
140 |
outputs=[
|
141 |
gr.Image(type="filepath", label="Mask"),
|
142 |
-
gr.Image(type="filepath", label="
|
|
|
|
|
143 |
],
|
144 |
allow_flagging="never",
|
145 |
css=css
|
|
|
5 |
import gradio as gr
|
6 |
import spaces
|
7 |
|
8 |
+
from PIL import Image, ImageOps
|
9 |
from transformers import AutoModelForImageSegmentation
|
10 |
from torchvision import transforms
|
11 |
|
|
|
78 |
preds = birefnet(image_proc.to(device))[-1].sigmoid().cpu()
|
79 |
pred = preds[0].squeeze()
|
80 |
|
81 |
+
# Process Results
|
82 |
pred_pil = transforms.ToPILImage()(pred)
|
83 |
pred_pil = pred_pil.resize(original_size, Image.BICUBIC) # Resize mask to original size
|
84 |
+
|
85 |
+
# Create reverse mask
|
86 |
+
reverse_mask = Image.new('L', original_size)
|
87 |
+
reverse_mask.paste(ImageOps.invert(pred_pil))
|
88 |
+
|
89 |
+
# Create foreground image (object with transparent background)
|
90 |
+
foreground = image_ori.copy()
|
91 |
+
foreground.putalpha(pred_pil)
|
92 |
+
|
93 |
+
# Create background image
|
94 |
+
background = image_ori.copy()
|
95 |
+
background.putalpha(reverse_mask)
|
96 |
|
97 |
torch.cuda.empty_cache()
|
98 |
|
99 |
+
# Save all images
|
100 |
mask_path = "mask.png"
|
101 |
pred_pil.save(mask_path)
|
102 |
|
103 |
+
reverse_mask_path = "reverse_mask.png"
|
104 |
+
reverse_mask.save(reverse_mask_path)
|
105 |
+
|
106 |
+
foreground_path = "foreground.png"
|
107 |
+
foreground.save(foreground_path)
|
108 |
+
|
109 |
+
background_path = "background.png"
|
110 |
+
background.save(background_path)
|
111 |
|
112 |
+
return mask_path, reverse_mask_path, foreground_path, background_path
|
113 |
|
114 |
css = """
|
115 |
body {
|
|
|
154 |
inputs=gr.Image(type="numpy"),
|
155 |
outputs=[
|
156 |
gr.Image(type="filepath", label="Mask"),
|
157 |
+
gr.Image(type="filepath", label="Reverse Mask"),
|
158 |
+
gr.Image(type="filepath", label="Foreground"),
|
159 |
+
gr.Image(type="filepath", label="Background")
|
160 |
],
|
161 |
allow_flagging="never",
|
162 |
css=css
|