petergpt commited on
Commit
aff9838
Β·
verified Β·
1 Parent(s): c9473c9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -6
app.py CHANGED
@@ -4,6 +4,7 @@ from transformers import AutoModelForImageSegmentation
4
  from PIL import Image
5
  from torchvision import transforms
6
  import gradio as gr
 
7
 
8
  # Load the model from Hugging Face
9
  birefnet = AutoModelForImageSegmentation.from_pretrained('zhengpeng7/BiRefNet_lite', trust_remote_code=True)
@@ -11,7 +12,7 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu'
11
  birefnet.to(device)
12
  birefnet.eval()
13
 
14
- # Define the transform to preprocess the input image
15
  image_size = (1024, 1024)
16
  transform_image = transforms.Compose([
17
  transforms.Resize(image_size),
@@ -19,10 +20,11 @@ transform_image = transforms.Compose([
19
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
20
  ])
21
 
22
- def extract_objects(images):
23
- start_time = time.time()
 
24
 
25
- # Transform all images into a batch
26
  inputs = []
27
  original_sizes = []
28
  for img in images:
@@ -70,10 +72,10 @@ def extract_objects(images):
70
 
71
  iface = gr.Interface(
72
  fn=extract_objects,
73
- inputs=gr.Image(type="pil", label="Upload Images", multiple=True),
74
  outputs=[gr.Gallery(label="Processed Images"), gr.Textbox(label="Timing Info")],
75
  title="BiRefNet Bulk Background Removal",
76
- description="Upload multiple images and process them in bulk. The request is handled at once, not sequentially. Timing information is also provided."
77
  )
78
 
79
  if __name__ == "__main__":
 
4
  from PIL import Image
5
  from torchvision import transforms
6
  import gradio as gr
7
+ import os
8
 
9
  # Load the model from Hugging Face
10
  birefnet = AutoModelForImageSegmentation.from_pretrained('zhengpeng7/BiRefNet_lite', trust_remote_code=True)
 
12
  birefnet.to(device)
13
  birefnet.eval()
14
 
15
+ # Define the transform to preprocess the input images
16
  image_size = (1024, 1024)
17
  transform_image = transforms.Compose([
18
  transforms.Resize(image_size),
 
20
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
21
  ])
22
 
23
+ def extract_objects(files):
24
+ # Open all images from the uploaded files
25
+ images = [Image.open(file.name).convert("RGB") for file in files]
26
 
27
+ start_time = time.time()
28
  inputs = []
29
  original_sizes = []
30
  for img in images:
 
72
 
73
  iface = gr.Interface(
74
  fn=extract_objects,
75
+ inputs=gr.Files(label="Upload Multiple Images", type="file", file_count="multiple"),
76
  outputs=[gr.Gallery(label="Processed Images"), gr.Textbox(label="Timing Info")],
77
  title="BiRefNet Bulk Background Removal",
78
+ description="Upload multiple images and process them in one request. Timing information for the full request and per-image processing is provided."
79
  )
80
 
81
  if __name__ == "__main__":