File size: 4,790 Bytes
d73301b
88a3311
5c0ed2c
d73301b
8423e33
d73301b
fac2e01
d73301b
 
 
88a3311
 
4c6f550
 
 
 
 
 
 
2b469d7
4c6f550
 
 
 
 
 
4caae8a
 
 
 
 
 
401fd9e
d73301b
401fd9e
 
 
e06181a
5c0ed2c
 
4caae8a
 
 
5c0ed2c
e06181a
5c0ed2c
4caae8a
5c0ed2c
 
 
 
04c67c9
 
 
 
 
 
 
 
 
5c0ed2c
b709a2a
45070c2
8423e33
 
 
 
e06181a
4caae8a
8423e33
e06181a
8423e33
e06181a
d73301b
 
8423e33
 
 
d73301b
5c0ed2c
d73301b
 
 
9f420c3
fac2e01
 
f9d0a7f
9f420c3
 
 
b1bcf24
d73301b
 
 
b1bcf24
d73301b
e55d440
f9d0a7f
844b5a5
e55d440
d73301b
 
 
5c0ed2c
 
4c6f550
401fd9e
 
 
 
 
 
09cdf11
f9d0a7f
9e90c80
5c0ed2c
b1bcf24
 
 
 
0560d17
 
5c0ed2c
d73301b
 
 
9f420c3
88a3311
4c6f550
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import gradio as gr
import rembg
from rembg import remove, new_session
from PIL import Image
import numpy as np
import logging
import time

# Configure root logging at INFO level so the app's progress messages are emitted.
logging.basicConfig(level=logging.INFO)
# Record the installed rembg version in the log to aid debugging.
logging.info("rembg version: %s", rembg.__version__)
# Models offered in the UI. Keys are the rembg model names passed to
# ``new_session``; values are the human-readable descriptions shown next to
# them in the model dropdown. The "" key is a placeholder meaning
# "no model selected" (rembg then uses its default model).
#
# Only names that rembg recognises may appear here: an unknown key makes
# ``new_session`` raise when the user picks it. (A former "unet" entry was
# removed for that reason.) The Segment Anything ("sam") model is
# deliberately not offered.
MODEL_OPTIONS = {
    "": "Select a model",
    "u2net": "A pre-trained model for general use cases (default)",
    "isnet-general-use": "A new pre-trained model for general use cases",
    "isnet-anime": "High-accuracy segmentation for anime characters",
    "silueta": "A reduced-size version of u2net (43MB)",
    "u2netp": "A lightweight version of u2net model",
    "u2net_human_seg": "A pre-trained model for human segmentation",
    "u2net_cloth_seg": "A pre-trained model for cloth parsing in human portraits",
}

def hex_to_rgba(hex_color):
    """Convert a hex color string to an ``(R, G, B, A)`` tuple of ints in 0-255.

    Accepts the ``RGB``, ``RGBA``, ``RRGGBB`` and ``RRGGBBAA`` forms, with or
    without a leading ``#``, in either letter case. When no alpha component
    is given, the color is fully opaque (alpha 255).

    Args:
        hex_color: The color string, e.g. ``"#ff8800"`` or ``"#f80"``.

    Returns:
        A 4-tuple ``(red, green, blue, alpha)``.

    Raises:
        ValueError: If the string has an unsupported number of digits or
            contains non-hexadecimal characters.
    """
    digits = hex_color.strip().lstrip('#')
    # Reject signs, spaces, "0x" etc. up front: int(..., 16) would accept
    # some of them and produce out-of-range channel values.
    if any(ch not in "0123456789abcdefABCDEF" for ch in digits):
        raise ValueError(f"Invalid hex color: {hex_color!r}")
    if len(digits) in (3, 4):
        # CSS shorthand: every digit is doubled, e.g. "f80" -> "ff8800".
        digits = ''.join(ch * 2 for ch in digits)
    if len(digits) == 6:
        digits += 'FF'  # Add full opacity if no alpha is provided
    if len(digits) != 8:
        raise ValueError(f"Invalid hex color: {hex_color!r}")
    return tuple(int(digits[i:i + 2], 16) for i in (0, 2, 4, 6))

# Cache of rembg sessions keyed by model name. Creating a session loads the
# model, so reusing it avoids reloading the model on every request.
_SESSION_CACHE = {}


def _get_session(model_name):
    """Return a (cached) rembg session for *model_name*, creating it on first use."""
    session = _SESSION_CACHE.get(model_name)
    if session is None:
        session = new_session(model_name)
        _SESSION_CACHE[model_name] = session
    return session


def remove_background(input_image, bg_color, model_choice, alpha_matting, post_process_mask, only_mask):
    """Remove the background from *input_image* and save the result as a PNG.

    Args:
        input_image: The uploaded PIL image, or ``None`` if nothing was uploaded.
        bg_color: Hex color string used to fill the removed background, or a
            falsy value to leave the background transparent.
        model_choice: Dropdown value of the form ``"<model> | <description>"``;
            an empty value lets rembg use its default model.
        alpha_matting: Whether to refine mask edges with alpha matting.
        post_process_mask: Whether rembg should post-process the mask.
        only_mask: If true, return the segmentation mask instead of the cutout.

    Returns:
        ``(output_image, output_path)`` on success; ``(None, None)`` if no image
        was given or processing failed (the error is logged with its traceback).
    """
    import uuid  # used only to make output file names unique

    if input_image is None:
        logging.warning("No input image provided")
        return None, None

    try:
        # Extract the model name from the "<model> | <description>" choice.
        model_name = model_choice.split(' | ')[0] if model_choice else ""

        # Reuse a cached session for the chosen model; None lets rembg pick its default.
        session = _get_session(model_name) if model_name else None

        # Convert hex color to RGBA tuple.
        bg_color_rgba = hex_to_rgba(bg_color) if bg_color else None

        remove_kwargs = {
            "session": session,
            "bgcolor": bg_color_rgba,
            "alpha_matting": alpha_matting,
            "post_process_mask": post_process_mask,
            "only_mask": only_mask,
        }

        if alpha_matting:
            remove_kwargs.update({
                # rembg masks are 8-bit (0-255); the previous value of 270 could
                # never be exceeded. 240 is rembg's own default.
                "alpha_matting_foreground_threshold": 240,
                "alpha_matting_background_threshold": 20,
                "alpha_matting_erode_size": 11,
            })

        logging.info(f'Model name={model_name}')
        logging.info(remove_kwargs)

        input_array = np.array(input_image)

        # Always forward the user's options. Previously they were silently dropped
        # when neither a model nor a background color was chosen. None entries are
        # omitted so rembg falls back to its own defaults for them.
        output_array = remove(
            input_array,
            **{k: v for k, v in remove_kwargs.items() if v is not None},
        )
        logging.info("Background removed")

        output_image = Image.fromarray(output_array)

        # Keep the alpha channel: flattening RGBA to RGB would discard the
        # transparency of the removed background in the saved PNG. Only
        # normalize unusual modes.
        if not only_mask and output_image.mode not in ("RGB", "RGBA"):
            output_image = output_image.convert("RGBA")
            logging.info("Converted to RGBA mode")

        # Timestamp for readability plus a random suffix so that concurrent
        # requests within the same second do not overwrite each other's file.
        timestamp = time.strftime("%Y%m%d-%H%M%S")
        output_path = f"output_remove_background_{timestamp}_{uuid.uuid4().hex[:8]}.png"
        output_image.save(output_path)
        logging.info(f"Saved output image {output_path}")

        return output_image, output_path

    except Exception:
        # UI boundary: log the full traceback and return empty outputs rather than crash.
        logging.exception("An error occurred while removing the background")
        return None, None

# Example inputs shown under the interface (only the image column is supplied).
examples = [['scifi_man1.jpg']]

# Dropdown entries read "<model key> | <description>"; the blank first entry
# stands for "no model chosen". The placeholder "" key is skipped.
_model_choices = [""] + [
    f"{key} | {desc}" for key, desc in MODEL_OPTIONS.items() if key != ""
]

# Input components, in the same order as remove_background's parameters.
_input_components = [
    gr.Image(type="pil"),
    gr.ColorPicker(label="Background Color", value=None),
    gr.Dropdown(
        choices=_model_choices,
        label="Model Selection",
        value="",
        type="value",
    ),
    gr.Checkbox(label="Enable Alpha Matting", value=False),
    gr.Checkbox(label="Post-Process Mask (post process the mask to get better results)", value=False),
    gr.Checkbox(label="Only Return Mask ", value=False),
]

# Output components: the processed image plus a downloadable copy of the saved file.
_output_components = [
    gr.Image(type="pil", label="Output Image"),
    gr.File(label="Download the output image"),
]

iface = gr.Interface(
    fn=remove_background,
    inputs=_input_components,
    outputs=_output_components,
    examples=examples,
    title="Advanced Background Remover v2.7",
    description="Upload an image to remove the background. Customize the result with different options, including background color, model selection, alpha matting, and more.",
    allow_flagging="never",
)

if __name__ == "__main__":
    # Start the web UI only when executed as a script, not when imported.
    iface.launch()