File size: 3,435 Bytes
2012f74
 
 
699e9ad
 
2012f74
 
699e9ad
2012f74
 
 
 
 
 
699e9ad
 
 
 
 
 
2012f74
 
 
 
 
 
 
 
 
 
 
 
 
699e9ad
2012f74
 
699e9ad
 
1c39d4f
001870b
 
699e9ad
 
 
 
 
2012f74
 
 
 
001870b
 
2012f74
 
 
 
 
 
 
1c39d4f
 
 
 
 
699e9ad
1c39d4f
 
699e9ad
 
 
 
1c39d4f
 
 
 
699e9ad
1c39d4f
03bfe96
699e9ad
1c39d4f
 
699e9ad
 
 
 
1c39d4f
 
 
 
699e9ad
2012f74
 
 
 
4d03b45
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import gradio as gr
import torch
from carvekit.api.interface import Interface
from carvekit.ml.wrap.basnet import BASNET
from carvekit.ml.wrap.deeplab_v3 import DeepLabV3
from carvekit.ml.wrap.fba_matting import FBAMatting
from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7
from carvekit.ml.wrap.u2net import U2NET
from carvekit.pipelines.postprocessing import MattingMethod
from carvekit.pipelines.preprocessing import PreprocessingStub
from carvekit.trimap.generator import TrimapGenerator

# Inference device shared by every model: CUDA when PyTorch can see a GPU, else CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Segmentation models selectable in the UI, keyed by the name shown in the
# model dropdowns. Every model is constructed up front, once, at import time,
# in the order listed here.
segment_net = {
    model_name: model_cls(device=device, batch_size=1)
    for model_name, model_cls in (
        ("U2NET", U2NET),
        ("BASNET", BASNET),
        ("DeepLabV3", DeepLabV3),
        ("TracerUniversalB7", TracerUniversalB7),
    )
}

# Matting stage shared by all segmentors: MattingMethod (below) combines this
# FBA matting network with the trimap generator.
fba = FBAMatting(batch_size=1, device=device, input_tensor_size=2048)

# Also used directly by generate_trimap() to build a trimap from a mask.
trimap = TrimapGenerator()

# Pipeline stages handed to carvekit's Interface as pre_pipe / post_pipe.
preprocessing = PreprocessingStub()
postprocessing = MattingMethod(
    device=device,
    matting_module=fba,
    trimap_generator=trimap,
)

# Dropdown options: the model names, in segment_net's insertion order.
method_choices = list(segment_net)


def generate_trimap(method, original):
    """Segment ``original`` with the chosen model and build a trimap from the mask.

    Args:
        method: Key into ``segment_net`` selecting the segmentation model
            (one of ``method_choices``).
        original: Input image from the ``gr.Image(type="pil")`` component, or
            ``None`` when the user clicked the button without supplying one.

    Returns:
        Whatever ``trimap`` produces for the segmentation mask, or ``None``
        when no input image was given (the output component stays empty).
    """
    # Gradio delivers None for an empty image input; don't hand None to the
    # segmentation model, which would fail with an unhelpful error.
    if original is None:
        return None
    segmentor = segment_net[method]
    # The segmentor is called with a one-element list; use the first mask.
    masks = segmentor([original])
    return trimap(original_image=original, mask=masks[0])


def predict(method, image):
    """Run the CarveKit pipeline on ``image`` with the selected segmentation model.

    The pipeline uses the module-level ``preprocessing`` stage, the chosen
    segmentor from ``segment_net``, and the ``postprocessing`` matting stage.

    Args:
        method: Key into ``segment_net`` selecting the segmentation model
            (one of ``method_choices``).
        image: Input image from the ``gr.Image(type="pil")`` component, or
            ``None`` when the user clicked the button without supplying one.

    Returns:
        The first (only) image returned by the pipeline, or ``None`` when no
        input image was given (the output component stays empty).
    """
    # Gradio delivers None for an empty image input; don't push it through
    # the pipeline, which would fail with an unhelpful error.
    if image is None:
        return None
    # Separate name so the `method` string is not rebound to a model object.
    seg_model = segment_net[method]
    pipeline = Interface(pre_pipe=preprocessing,
                         post_pipe=postprocessing,
                         seg_pipe=seg_model)
    # The pipeline is called with a one-element list; return its single result.
    return pipeline([image])[0]


# HTML rendered at the bottom of the app: the CarveKit logo and a link to the
# CarveKit repository. Uses a valid <br> line break (the previous `</br>` is
# not a real HTML tag and only works through parser error recovery).
footer = r"""
<center>
<img src='https://raw.githubusercontent.com/leonelhs/image-background-remove-tool/master/docs/imgs/logo.png' alt='CarveKit' width="200" height="80">
<br>
<b>
Demo based on <a href='https://github.com/OPHoperHPO/image-background-remove-tool'>CarveKit</a>
</b>
</center>
"""

# Two-tab Gradio UI: full background removal, and a trimap-only preview.
with gr.Blocks(title="CarveKit") as app:
    gr.Markdown("<center><h1><b>CarveKit</b></h1></center>")
    gr.HTML("<center><h3>High-quality image background removal</h3></center>")

    with gr.Tabs() as tabs:
        # Tab 1: full pipeline via predict() (preprocessing, segmentation,
        # matting post-processing).
        with gr.TabItem("Remove background", id=0):
            with gr.Row(equal_height=False):
                with gr.Column():
                    input_img = gr.Image(type="pil", label="Input image")
                    drp_itf = gr.Dropdown(
                        value="TracerUniversalB7",
                        label="Segmentor model",
                        choices=method_choices)
                    run_btn = gr.Button(variant="primary")
                with gr.Column():
                    output_img = gr.Image(type="pil", label="result")

            run_btn.click(predict, [drp_itf, input_img], [output_img])

        # Tab 2: show only the trimap built from the segmentation mask.
        with gr.TabItem("Trimap generator", id=1):
            with gr.Row(equal_height=False):
                with gr.Column():
                    trimap_input = gr.Image(type="pil", label="Input image")
                    # Own name, so this selector does not rebind `drp_itf`,
                    # which belongs to the first tab.
                    trimap_drp = gr.Dropdown(
                        value="TracerUniversalB7",
                        label="Segmentor model",
                        choices=method_choices)
                    trimap_btn = gr.Button(variant="primary")
                with gr.Column():
                    trimap_output = gr.Image(type="pil", label="result")

            trimap_btn.click(generate_trimap, [trimap_drp, trimap_input], [trimap_output])

    with gr.Row():
        gr.HTML(footer)

# Enable Gradio's request queue, then start the server (blocking, with errors
# surfaced in the UI).
app.queue()
app.launch(share=False, debug=True, show_error=True)