Update app.py
Browse files
app.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
import torch
|
2 |
from transformers import AutoModelForImageSegmentation
|
3 |
from PIL import Image
|
@@ -18,25 +19,61 @@ transform_image = transforms.Compose([
|
|
18 |
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
19 |
])
|
20 |
|
21 |
-
def
|
22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
with torch.no_grad():
|
24 |
-
preds = birefnet(
|
25 |
-
|
26 |
-
|
27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
|
29 |
-
|
30 |
-
result = Image.new("RGBA", image.size, (0, 0, 0, 0))
|
31 |
-
result.paste(image, mask=mask)
|
32 |
-
return result
|
33 |
|
34 |
iface = gr.Interface(
|
35 |
-
fn=
|
36 |
-
inputs=gr.Image(type="pil", label="Upload
|
37 |
-
outputs=gr.
|
38 |
-
title="BiRefNet Background Removal",
|
39 |
-
description="Upload
|
40 |
)
|
41 |
|
42 |
if __name__ == "__main__":
|
|
|
1 |
+
import time
|
2 |
import torch
|
3 |
from transformers import AutoModelForImageSegmentation
|
4 |
from PIL import Image
|
|
|
19 |
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
20 |
])
|
21 |
|
22 |
+
def extract_objects(images):
|
23 |
+
start_time = time.time()
|
24 |
+
|
25 |
+
# Transform all images into a batch
|
26 |
+
inputs = []
|
27 |
+
original_sizes = []
|
28 |
+
for img in images:
|
29 |
+
original_sizes.append(img.size)
|
30 |
+
inputs.append(transform_image(img))
|
31 |
+
input_tensor = torch.stack(inputs).to(device)
|
32 |
+
|
33 |
+
# Inference
|
34 |
+
inf_start = time.time()
|
35 |
with torch.no_grad():
|
36 |
+
preds = birefnet(input_tensor)[-1].sigmoid().cpu()
|
37 |
+
inf_end = time.time()
|
38 |
+
|
39 |
+
# Post-process results
|
40 |
+
results = []
|
41 |
+
image_times = []
|
42 |
+
for i, img in enumerate(images):
|
43 |
+
t_start = time.time()
|
44 |
+
pred = preds[i].squeeze()
|
45 |
+
pred_pil = transforms.ToPILImage()(pred)
|
46 |
+
mask = pred_pil.resize(original_sizes[i])
|
47 |
+
|
48 |
+
# Create a transparent background image
|
49 |
+
result = Image.new("RGBA", original_sizes[i], (0, 0, 0, 0))
|
50 |
+
result.paste(img, mask=mask)
|
51 |
+
results.append(result)
|
52 |
+
t_end = time.time()
|
53 |
+
image_times.append(t_end - t_start)
|
54 |
+
|
55 |
+
end_time = time.time()
|
56 |
+
total_time = end_time - start_time
|
57 |
+
inference_time = inf_end - inf_start
|
58 |
+
prep_post_time = total_time - inference_time
|
59 |
+
|
60 |
+
# Create a summary of timings
|
61 |
+
summary = (
|
62 |
+
f"Total request time: {total_time:.2f} s\n"
|
63 |
+
f"Inference time (batch): {inference_time:.2f} s\n"
|
64 |
+
f"Pre/Post-processing time: {prep_post_time:.2f} s\n"
|
65 |
+
"Per-image post-processing times:\n" +
|
66 |
+
"\n".join([f" Image {i+1}: {t:.2f} s" for i, t in enumerate(image_times)])
|
67 |
+
)
|
68 |
|
69 |
+
return results, summary
|
|
|
|
|
|
|
70 |
|
71 |
iface = gr.Interface(
|
72 |
+
fn=extract_objects,
|
73 |
+
inputs=gr.Image(type="pil", label="Upload Images", multiple=True),
|
74 |
+
outputs=[gr.Gallery(label="Processed Images"), gr.Textbox(label="Timing Info")],
|
75 |
+
title="BiRefNet Bulk Background Removal",
|
76 |
+
description="Upload multiple images and process them in bulk. The request is handled at once, not sequentially. Timing information is also provided."
|
77 |
)
|
78 |
|
79 |
if __name__ == "__main__":
|