petergpt commited on
Commit
ea3db16
·
verified ·
1 Parent(s): 3472d22

remove mask

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -98,6 +98,7 @@ net = build_model(hypar, device)
98
  def inference(file_paths, logs):
99
  """
100
  Process up to 3 images uploaded via the file uploader.
 
101
  """
102
  start_time = time.time()
103
  logs = logs or ""
@@ -107,7 +108,7 @@ def inference(file_paths, logs):
107
 
108
  # Limit to a maximum of 3 images
109
  image_paths = file_paths[:3]
110
- processed_pairs = []
111
  for path in image_paths:
112
  image_tensor, orig_size = load_image(path, hypar)
113
  mask = predict(net, image_tensor, orig_size, hypar, device)
@@ -115,12 +116,11 @@ def inference(file_paths, logs):
115
  im_rgb = Image.open(path).convert("RGB")
116
  im_rgba = im_rgb.copy()
117
  im_rgba.putalpha(pil_mask)
118
- processed_pairs.append([im_rgba, pil_mask])
119
 
120
  elapsed = round(time.time() - start_time, 2)
121
- final_images = [img for pair in processed_pairs for img in pair]
122
- logs += f"Processed {len(processed_pairs)} image(s) in {elapsed} second(s).\n"
123
- return final_images, logs, logs
124
 
125
  title = "Highly Accurate Dichotomous Image Segmentation"
126
  description = (
@@ -142,7 +142,7 @@ interface = gr.Interface(
142
  gr.State()
143
  ],
144
  outputs=[
145
- gr.Gallery(label="Output (rgba + mask)"),
146
  gr.State(),
147
  gr.Textbox(label="Logs", lines=6)
148
  ],
 
98
  def inference(file_paths, logs):
99
  """
100
  Process up to 3 images uploaded via the file uploader.
101
+ Only the image with background removed is returned.
102
  """
103
  start_time = time.time()
104
  logs = logs or ""
 
108
 
109
  # Limit to a maximum of 3 images
110
  image_paths = file_paths[:3]
111
+ processed_images = []
112
  for path in image_paths:
113
  image_tensor, orig_size = load_image(path, hypar)
114
  mask = predict(net, image_tensor, orig_size, hypar, device)
 
116
  im_rgb = Image.open(path).convert("RGB")
117
  im_rgba = im_rgb.copy()
118
  im_rgba.putalpha(pil_mask)
119
+ processed_images.append(im_rgba)
120
 
121
  elapsed = round(time.time() - start_time, 2)
122
+ logs += f"Processed {len(processed_images)} image(s) in {elapsed} second(s).\n"
123
+ return processed_images, logs, logs
 
124
 
125
  title = "Highly Accurate Dichotomous Image Segmentation"
126
  description = (
 
142
  gr.State()
143
  ],
144
  outputs=[
145
+ gr.Gallery(label="Output (Background Removed)"),
146
  gr.State(),
147
  gr.Textbox(label="Logs", lines=6)
148
  ],