File size: 3,877 Bytes
47f51f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import os
import gradio as gr
from gradio_imageslider import ImageSlider
from loadimg import load_img
#import spaces
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
import numpy as np
from PIL import Image

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Use PyTorch's "high" (rather than "highest") float32 matmul precision setting.
torch.set_float32_matmul_precision("high")

# Load the RMBG-2.0 segmentation checkpoint (trust_remote_code=True lets the
# checkpoint's own model code run) and place it on the selected device.
# nn.Module.to returns the module itself, so the call can be chained.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "briaai/RMBG-2.0", trust_remote_code=True
).to(device)

# Per-image preprocessing: resize to 1024x1024, convert to a float tensor,
# then normalise with the standard ImageNet channel mean / std.
transform_image = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Directory (relative to the working directory) where fn() writes its PNGs.
output_folder = 'output_images'
# exist_ok=True replaces the racy "check, then create" pair: it is a no-op if
# the directory already exists (e.g. created concurrently by another process).
os.makedirs(output_folder, exist_ok=True)

# Mask colours as "#RRGGBB" strings, one per mask; index 0 is the background.
colors = [
    '#000000',  # background
    '#2692F3',  # blue
    '#F89E12',  # orange
    '#16C232',  # green
    '#F92F6C',  # pink
    '#AC6AEB',  # purple
]

# (N, 3) integer array of RGB triples for every non-background colour:
# strip the leading '#', decode the hex digits into bytes, and read them as ints.
palette = np.array([tuple(bytes.fromhex(hex_code[1:])) for hex_code in colors[1:]])

def fn(image, mask_color):
    """Remove the background from *image* and write the results to disk.

    Args:
        image: any input ``load_img`` accepts (e.g. the value of the
            ``gr.Image`` input). ``None`` is rejected with a user-facing error.
        mask_color: ``"#RRGGBB"`` colour used for the mask preview.

    Returns:
        A 3-tuple matching the UI outputs: ``(cutout, original)`` for the
        before/after slider, the path of the saved cutout PNG, and the
        coloured mask image.

    Raises:
        gr.Error: if no image was supplied.
    """
    if image is None:
        # Fail with a clear message instead of handing None to load_img.
        raise gr.Error("Please upload an image or paste an image URL.")

    source = load_img(image, output_type="pil").convert("RGB")
    original = source.copy()  # untouched copy for the comparison slider
    cutout, colored_mask = process(source, mask_color)

    # NOTE(review): fixed filenames are overwritten on every call, so
    # overlapping requests could clobber each other's files.
    image_path = os.path.join(output_folder, "no_bg_image.png")
    mask_path = os.path.join(output_folder, "mask_image.png")
    cutout.save(image_path)
    colored_mask.save(mask_path)
    return (cutout, original), image_path, colored_mask

#@spaces.GPU
def process(image, mask_color):
    """Segment the foreground of *image* with RMBG-2.0.

    Args:
        image: PIL image to segment (``fn`` passes an RGB image).
        mask_color: mask preview colour as a ``"#RRGGBB"`` hex string.

    Returns:
        ``(transparent_image, colored_mask)``: two RGBA PIL images the size of
        *image*. The first keeps the original pixels and uses the predicted
        mask as its alpha channel. The second is a solid *mask_color* fill
        carrying the same alpha.
    """
    image_size = image.size
    # Add a batch dimension and move the preprocessed tensor to the model device.
    input_images = transform_image(image).unsqueeze(0).to(device)
    with torch.no_grad():
        # Take the last element of the model output and apply a sigmoid so
        # values lie in [0, 1].
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    # Convert the prediction to a PIL image and scale it back to the input size.
    mask = transforms.ToPILImage()(pred).resize(image_size)

    # Attach the mask directly as the alpha channel. Pasting the RGB image
    # through the mask onto a fully transparent (0, 0, 0, 0) canvas blends
    # every band with that canvas, so the RGB values were also scaled by the
    # mask. That darkened all partially transparent pixels and left dark
    # fringes along soft edges.
    # convert() returns a new image, so the caller's image is not modified.
    transparent_image = image.convert("RGBA")
    transparent_image.putalpha(mask)

    # Solid-colour preview of the mask in the requested colour.
    mask_color_rgb = tuple(int(mask_color[i + 1:i + 3], 16) for i in (0, 2, 4))
    colored_mask = Image.new("RGBA", image_size, mask_color_rgb + (255,))
    colored_mask.putalpha(mask)

    return transparent_image, colored_mask

# Example inputs for the UI: a local file (resolved against the current
# working directory, so it must exist there) and a remote image URL.
example_image, example_url = (
    "giraffe.jpg",
    "http://farm9.staticflickr.com/8488/8228323072_76eeddfea3_z.jpg",
)

def _run_from_inputs(image, mask_color, url):
    """Click handler for Run: use the uploaded image, else the pasted URL.

    The URL textbox used to be rendered without being passed to any handler,
    so a pasted URL was silently ignored. This handler receives it as well.

    Raises:
        gr.Error: if neither an image nor a URL was provided.
    """
    source = image
    if source is None and url and url.strip():
        # NOTE(review): relies on load_img (called inside fn) accepting a URL
        # string — confirm against the loadimg version in use.
        source = url.strip()
    if source is None:
        raise gr.Error("Please upload an image or paste an image URL.")
    # Fall back to the first non-background colour if the dropdown is empty.
    return fn(source, mask_color or colors[1])


# Build the Gradio UI.
with gr.Blocks() as demo:
    gr.Markdown("# 🖼️ RMBG-2.0 for Background Removal")
    with gr.Row():
        # Left column: inputs
        with gr.Column():
            gr.Markdown("## Input")
            image_input = gr.Image(label="Upload an image")
            text_input = gr.Textbox(label="Paste an image URL")
            color_input = gr.Dropdown(label="Mask Color", choices=colors[1:], value=colors[1])
            run_button = gr.Button("Run")

        # Right column: outputs
        with gr.Column():
            gr.Markdown("## Output")
            slider_output = ImageSlider(label="RMBG-2.0", type="pil")
            file_output = gr.File(label="Output PNG File")
            mask_output = gr.Image(label="Mask Image")

    # Example rows; cache_examples=True precomputes their outputs with fn.
    gr.Examples(
        examples=[[example_image, colors[1]], [example_url, colors[1]]],
        inputs=[image_input, color_input],
        outputs=[slider_output, file_output, mask_output],
        fn=fn,
        cache_examples=True,
    )

    # Also pass the URL textbox so a pasted URL is used when nothing is uploaded.
    run_button.click(
        fn=_run_from_inputs,
        inputs=[image_input, color_input, text_input],
        outputs=[slider_output, file_output, mask_output],
    )

if __name__ == "__main__":
    # Launch with a public share link; show_error=True surfaces handler
    # exceptions in the UI.
    demo.launch(show_error=True, share=True)