petergpt commited on
Commit
c9473c9
Β·
verified Β·
1 Parent(s): 1018e38

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -15
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import torch
2
  from transformers import AutoModelForImageSegmentation
3
  from PIL import Image
@@ -18,25 +19,61 @@ transform_image = transforms.Compose([
18
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
19
  ])
20
 
21
- def extract_object(image):
22
- input_images = transform_image(image).unsqueeze(0).to(device)
 
 
 
 
 
 
 
 
 
 
 
23
  with torch.no_grad():
24
- preds = birefnet(input_images)[-1].sigmoid().cpu()
25
- pred = preds[0].squeeze()
26
- pred_pil = transforms.ToPILImage()(pred)
27
- mask = pred_pil.resize(image.size)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
- # Create a new transparent image
30
- result = Image.new("RGBA", image.size, (0, 0, 0, 0))
31
- result.paste(image, mask=mask)
32
- return result
33
 
34
  iface = gr.Interface(
35
- fn=extract_object,
36
- inputs=gr.Image(type="pil", label="Upload Image"),
37
- outputs=gr.Image(type="pil", label="Object with Transparent Background"),
38
- title="BiRefNet Background Removal",
39
- description="Upload an image and get the foreground object extracted onto a transparent background."
40
  )
41
 
42
  if __name__ == "__main__":
 
1
+ import time
2
  import torch
3
  from transformers import AutoModelForImageSegmentation
4
  from PIL import Image
 
19
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
20
  ])
21
 
22
+ def extract_objects(images):
23
+ start_time = time.time()
24
+
25
+ # Transform all images into a batch
26
+ inputs = []
27
+ original_sizes = []
28
+ for img in images:
29
+ original_sizes.append(img.size)
30
+ inputs.append(transform_image(img))
31
+ input_tensor = torch.stack(inputs).to(device)
32
+
33
+ # Inference
34
+ inf_start = time.time()
35
  with torch.no_grad():
36
+ preds = birefnet(input_tensor)[-1].sigmoid().cpu()
37
+ inf_end = time.time()
38
+
39
+ # Post-process results
40
+ results = []
41
+ image_times = []
42
+ for i, img in enumerate(images):
43
+ t_start = time.time()
44
+ pred = preds[i].squeeze()
45
+ pred_pil = transforms.ToPILImage()(pred)
46
+ mask = pred_pil.resize(original_sizes[i])
47
+
48
+ # Create a transparent background image
49
+ result = Image.new("RGBA", original_sizes[i], (0, 0, 0, 0))
50
+ result.paste(img, mask=mask)
51
+ results.append(result)
52
+ t_end = time.time()
53
+ image_times.append(t_end - t_start)
54
+
55
+ end_time = time.time()
56
+ total_time = end_time - start_time
57
+ inference_time = inf_end - inf_start
58
+ prep_post_time = total_time - inference_time
59
+
60
+ # Create a summary of timings
61
+ summary = (
62
+ f"Total request time: {total_time:.2f} s\n"
63
+ f"Inference time (batch): {inference_time:.2f} s\n"
64
+ f"Pre/Post-processing time: {prep_post_time:.2f} s\n"
65
+ "Per-image post-processing times:\n" +
66
+ "\n".join([f" Image {i+1}: {t:.2f} s" for i, t in enumerate(image_times)])
67
+ )
68
 
69
+ return results, summary
 
 
 
70
 
71
  iface = gr.Interface(
72
+ fn=extract_objects,
73
+ inputs=gr.Image(type="pil", label="Upload Images", multiple=True),
74
+ outputs=[gr.Gallery(label="Processed Images"), gr.Textbox(label="Timing Info")],
75
+ title="BiRefNet Bulk Background Removal",
76
+ description="Upload multiple images and process them in bulk. The request is handled at once, not sequentially. Timing information is also provided."
77
  )
78
 
79
  if __name__ == "__main__":