Spaces:
Running
on
Zero
Running
on
Zero
Add tab of batch inference with saving function.
Browse files
app.py
CHANGED
@@ -57,15 +57,37 @@ birefnet = AutoModelForImageSegmentation.from_pretrained('/'.join(('zhengpeng7',
|
|
57 |
birefnet.to(device)
|
58 |
birefnet.eval()
|
59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
|
61 |
@spaces.GPU
|
62 |
-
def predict(
|
63 |
-
assert (
|
64 |
-
|
65 |
-
if isinstance(image, str):
|
66 |
-
response = requests.get(image)
|
67 |
-
image_data = BytesIO(response.content)
|
68 |
-
image = np.array(Image.open(image_data))
|
69 |
global birefnet
|
70 |
# Load BiRefNet with chosen weights
|
71 |
_weights_file = '/'.join(('zhengpeng7', usage_to_weights_file[weights_file] if weights_file is not None else usage_to_weights_file['General']))
|
@@ -74,33 +96,63 @@ def predict(image, resolution, weights_file):
|
|
74 |
birefnet.to(device)
|
75 |
birefnet.eval()
|
76 |
|
77 |
-
|
78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
# Preprocess the image
|
84 |
-
image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
|
85 |
-
image_proc = image_preprocessor.proc(image_pil)
|
86 |
-
image_proc = image_proc.unsqueeze(0)
|
87 |
-
|
88 |
-
# Perform the prediction
|
89 |
-
with torch.no_grad():
|
90 |
-
scaled_pred_tensor = birefnet(image_proc.to(device))[-1].sigmoid()
|
91 |
-
|
92 |
-
if device == 'cuda':
|
93 |
-
scaled_pred_tensor = scaled_pred_tensor.cpu()
|
94 |
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
|
103 |
-
|
|
|
|
|
|
|
|
|
104 |
|
105 |
return image, image_pred
|
106 |
|
@@ -118,6 +170,11 @@ examples_url = [
|
|
118 |
for idx_example_url, example_url in enumerate(examples_url):
|
119 |
examples_url[idx_example_url].append('1024x1024')
|
120 |
|
|
|
|
|
|
|
|
|
|
|
121 |
tab_image = gr.Interface(
|
122 |
fn=predict,
|
123 |
inputs=[
|
@@ -128,10 +185,7 @@ tab_image = gr.Interface(
|
|
128 |
outputs=ImageSlider(label="BiRefNet's prediction", type="pil"),
|
129 |
examples=examples,
|
130 |
api_name="image",
|
131 |
-
description=
|
132 |
-
' The resolution used in our training was `1024x1024`, thus the suggested resolution to obtain good results!\n'
|
133 |
-
' Our codes can be found at https://github.com/ZhengPeng7/BiRefNet.\n'
|
134 |
-
' We also maintain the HF model of BiRefNet at https://huggingface.co/ZhengPeng7/BiRefNet for easier access.'),
|
135 |
)
|
136 |
|
137 |
tab_text = gr.Interface(
|
@@ -144,15 +198,20 @@ tab_text = gr.Interface(
|
|
144 |
outputs=ImageSlider(label="BiRefNet's prediction", type="pil"),
|
145 |
examples=examples_url,
|
146 |
api_name="text",
|
147 |
-
description=
|
148 |
-
|
149 |
-
|
150 |
-
|
|
|
|
|
|
|
|
|
|
|
151 |
)
|
152 |
|
153 |
demo = gr.TabbedInterface(
|
154 |
-
[tab_image, tab_text],
|
155 |
-
[
|
156 |
title="BiRefNet demo for subject extraction (general / salient / camouflaged / portrait).",
|
157 |
)
|
158 |
|
|
|
57 |
birefnet.to(device)
|
58 |
birefnet.eval()
|
59 |
|
60 |
+
# for idx, image_path in enumerate(images):
|
61 |
+
# im = load_img(image_path, output_type="pil")
|
62 |
+
# if im is None:
|
63 |
+
# continue
|
64 |
+
|
65 |
+
# im = im.convert("RGB")
|
66 |
+
# image_size = im.size
|
67 |
+
# input_images = transform_image(im).unsqueeze(0).to("cpu")
|
68 |
+
|
69 |
+
# with torch.no_grad():
|
70 |
+
# preds = birefnet(input_images)[-1].sigmoid().cpu()
|
71 |
+
# pred = preds[0].squeeze()
|
72 |
+
# pred_pil = transforms.ToPILImage()(pred)
|
73 |
+
# mask = pred_pil.resize(image_size)
|
74 |
+
|
75 |
+
# im.putalpha(mask)
|
76 |
+
# output_file_path = os.path.join(save_dir, f"output_image_batch_{idx + 1}.png")
|
77 |
+
# im.save(output_file_path)
|
78 |
+
# output_paths.append(output_file_path)
|
79 |
+
|
80 |
+
# zip_file_path = os.path.join(save_dir, "processed_images.zip")
|
81 |
+
# with zipfile.ZipFile(zip_file_path, 'w') as zipf:
|
82 |
+
# for file in output_paths:
|
83 |
+
# zipf.write(file, os.path.basename(file))
|
84 |
+
|
85 |
+
# return output_paths, zip_file_path
|
86 |
|
87 |
@spaces.GPU
|
88 |
+
def predict(images, resolution, weights_file):
|
89 |
+
assert (images is not None), 'AssertionError: images cannot be None.'
|
90 |
+
|
|
|
|
|
|
|
|
|
91 |
global birefnet
|
92 |
# Load BiRefNet with chosen weights
|
93 |
_weights_file = '/'.join(('zhengpeng7', usage_to_weights_file[weights_file] if weights_file is not None else usage_to_weights_file['General']))
|
|
|
96 |
birefnet.to(device)
|
97 |
birefnet.eval()
|
98 |
|
99 |
+
try:
|
100 |
+
resolution = [int(int(reso)//32*32) for reso in resolution.strip().split('x')]
|
101 |
+
except:
|
102 |
+
resolution = [1024, 1024]
|
103 |
+
print('Invalid resolution input. Automatically changed to 1024x1024.')
|
104 |
+
|
105 |
+
if isinstance(images, list):
|
106 |
+
save_dir = 'preds-BiRefNet'
|
107 |
+
if not os.path.exists(save_dir):
|
108 |
+
os.makedirs(save_dir)
|
109 |
+
else:
|
110 |
+
# For tab_batch
|
111 |
+
save_paths = []
|
112 |
+
images = [images]
|
113 |
+
|
114 |
+
for idx_image, image_src in enumerate(images):
|
115 |
+
if isinstance(image_src, str):
|
116 |
+
response = requests.get(image_src)
|
117 |
+
image_data = BytesIO(response.content)
|
118 |
+
image = np.array(Image.open(image_data))
|
119 |
+
else:
|
120 |
+
image = image_src
|
121 |
|
122 |
+
image_shape = image.shape[:2]
|
123 |
+
image_pil = array_to_pil_image(image, tuple(resolution))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
124 |
|
125 |
+
# Preprocess the image
|
126 |
+
image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
|
127 |
+
image_proc = image_preprocessor.proc(image_pil)
|
128 |
+
image_proc = image_proc.unsqueeze(0)
|
129 |
+
|
130 |
+
# Perform the prediction
|
131 |
+
with torch.no_grad():
|
132 |
+
scaled_pred_tensor = birefnet(image_proc.to(device))[-1].sigmoid()
|
133 |
+
|
134 |
+
if device == 'cuda':
|
135 |
+
scaled_pred_tensor = scaled_pred_tensor.cpu()
|
136 |
+
|
137 |
+
# Resize the prediction to match the original image shape
|
138 |
+
pred = torch.nn.functional.interpolate(scaled_pred_tensor, size=image_shape, mode='bilinear', align_corners=True).squeeze().numpy()
|
139 |
+
|
140 |
+
# Apply the prediction mask to the original image
|
141 |
+
image_pil = image_pil.resize(pred.shape[::-1])
|
142 |
+
pred = np.repeat(np.expand_dims(pred, axis=-1), 3, axis=-1)
|
143 |
+
image_pred = (pred * np.array(image_pil)).astype(np.uint8)
|
144 |
+
|
145 |
+
torch.cuda.empty_cache()
|
146 |
+
|
147 |
+
save_file_path = os.path.join(save_dir, "{}.png".format(os.path.splitext(os.path.basename(image_src))[0]))
|
148 |
+
cv2.imwrite(save_file_path)
|
149 |
+
save_paths.append(save_file_path)
|
150 |
|
151 |
+
if len(images) > 1:
|
152 |
+
zip_file_path = os.path.join(save_dir, "{}.zip".format(save_dir))
|
153 |
+
with zipfile.ZipFile(zip_file_path, 'w') as zipf:
|
154 |
+
for file in save_paths:
|
155 |
+
zipf.write(file, os.path.basename(file))
|
156 |
|
157 |
return image, image_pred
|
158 |
|
|
|
170 |
# Give every URL example row a default resolution string as its next column.
for example_url in examples_url:
    example_url.append('1024x1024')
|
172 |
|
173 |
+
# Shared description text shown on every tab; individual tabs append their own
# acknowledgement line to it.
# Each fragment ends with '\n' inside the quotes: the previous version closed the
# first literal with "\n)'" and so rendered a stray ')' at the start of line two.
descriptions = (
    'Upload a picture, our model will extract a highly accurate segmentation of the subject in it.\n'
    ' The resolution used in our training was `1024x1024`, thus the suggested resolution to obtain good results!\n'
    ' Our codes can be found at https://github.com/ZhengPeng7/BiRefNet.\n'
    ' We also maintain the HF model of BiRefNet at https://huggingface.co/ZhengPeng7/BiRefNet for easier access.'
)
|
177 |
+
|
178 |
tab_image = gr.Interface(
|
179 |
fn=predict,
|
180 |
inputs=[
|
|
|
185 |
outputs=ImageSlider(label="BiRefNet's prediction", type="pil"),
|
186 |
examples=examples,
|
187 |
api_name="image",
|
188 |
+
description=descriptions,
|
|
|
|
|
|
|
189 |
)
|
190 |
|
191 |
tab_text = gr.Interface(
|
|
|
198 |
outputs=ImageSlider(label="BiRefNet's prediction", type="pil"),
|
199 |
examples=examples_url,
|
200 |
api_name="text",
|
201 |
+
description=descriptions+'\nTab-URL is partially modified from https://huggingface.co/spaces/not-lain/background-removal, thanks to this great work!',
|
202 |
+
)
|
203 |
+
|
204 |
+
# Batch tab: upload several image files and run them through `predict`.
# `predict(images, resolution, weights_file)` has three positional parameters,
# so the interface must supply three input components; with only the file
# picker, every call would fail for the two missing arguments.
# NOTE(review): `predict` currently ends with `return image, image_pred`; confirm
# its batch branch returns (gallery items, file path) to match `outputs` below.
tab_batch = gr.Interface(
    fn=predict,
    inputs=[
        gr.File(label="Upload multiple images", type="filepath", file_count="multiple"),
        # `predict` parses this as 'WxH' and falls back to 1024x1024 when invalid.
        gr.Textbox(lines=1, value='1024x1024', label="Resolution"),
        # presumably `usage_to_weights_file` is a dict keyed by usage name
        # (it is indexed with 'General' in `predict`) — TODO confirm.
        gr.Radio(list(usage_to_weights_file), value='General', label="Weights"),
    ],
    outputs=[gr.Gallery(label="BiRefNet's predictions"), gr.File(label="Download masked images.")],
    api_name="batch",
    description=descriptions+'\nTab-batch is partially modified from https://huggingface.co/spaces/NegiTurkey/Multi_Birefnetfor_Background_Removal, thanks to this great work!',
)
|
211 |
|
212 |
# Top-level app: one tab per interface, labelled in the same order.
demo = gr.TabbedInterface(
    interface_list=[tab_image, tab_text, tab_batch],
    tab_names=['image', 'text', 'batch'],
    title="BiRefNet demo for subject extraction (general / salient / camouflaged / portrait).",
)
|
217 |
|