rdjarbeng commited on
Commit
e06181a
·
1 Parent(s): 09cdf11

fix mismatch in remove, use default remove

Browse files
Files changed (1) hide show
  1. app.py +9 -5
app.py CHANGED
@@ -8,11 +8,11 @@ logging.basicConfig(level=logging.INFO)
8
 
9
  def remove_background(input_image, bg_color, model_name, alpha_matting, post_process_mask, only_mask):
10
  try:
11
- # Set up the session with the chosen model
12
  session = new_session(model_name) if model_name else None
13
 
14
  # Prepare additional options
15
- remove_options = {
16
  "session": session,
17
  "bgcolor": bg_color if bg_color else None,
18
  "alpha_matting": alpha_matting,
@@ -20,8 +20,12 @@ def remove_background(input_image, bg_color, model_name, alpha_matting, post_pro
20
  "only_mask": only_mask
21
  }
22
 
23
- # Remove the background
24
- output_image = remove(input_image, **{k: v for k, v in remove_options.items() if v is not None})
 
 
 
 
25
  logging.info("Background removed")
26
 
27
  # Convert to RGB mode if necessary
@@ -46,7 +50,7 @@ iface = gr.Interface(
46
  inputs=[
47
  gr.Image(type="pil"),
48
  gr.ColorPicker(label="Background Color", value=None), # Background color picker
49
- gr.Dropdown(choices=["u2net", "isnet-general-use", "unet"], label="Model Selection", value="u2net"),
50
  gr.Checkbox(label="Enable Alpha Matting", value=False),
51
  gr.Checkbox(label="Post-Process Mask", value=False),
52
  gr.Checkbox(label="Only Return Mask", value=False)
 
8
 
9
  def remove_background(input_image, bg_color, model_name, alpha_matting, post_process_mask, only_mask):
10
  try:
11
+ # Set up the session with the chosen model, or None if no model is selected
12
  session = new_session(model_name) if model_name else None
13
 
14
  # Prepare additional options
15
+ remove_kwargs = {
16
  "session": session,
17
  "bgcolor": bg_color if bg_color else None,
18
  "alpha_matting": alpha_matting,
 
20
  "only_mask": only_mask
21
  }
22
 
23
+ # Use the remove function
24
+ if session:
25
+ output_image = remove(input_image, **{k: v for k, v in remove_kwargs.items() if v is not None})
26
+ else:
27
+ output_image = remove(input_image) # Use the default remove function
28
+
29
  logging.info("Background removed")
30
 
31
  # Convert to RGB mode if necessary
 
50
  inputs=[
51
  gr.Image(type="pil"),
52
  gr.ColorPicker(label="Background Color", value=None), # Background color picker
53
+ gr.Dropdown(choices=["", "u2net", "isnet-general-use", "unet"], label="Model Selection", value=""),
54
  gr.Checkbox(label="Enable Alpha Matting", value=False),
55
  gr.Checkbox(label="Post-Process Mask", value=False),
56
  gr.Checkbox(label="Only Return Mask", value=False)