File size: 2,928 Bytes
2012f74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1c39d4f
 
 
001870b
 
2012f74
 
 
 
 
 
001870b
 
2012f74
 
 
 
 
 
 
 
1c39d4f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2012f74
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import gradio as gr
import torch
from carvekit.api.interface import Interface
from carvekit.ml.wrap.fba_matting import FBAMatting
from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7
from carvekit.pipelines.postprocessing import MattingMethod
from carvekit.pipelines.preprocessing import PreprocessingStub
from carvekit.trimap.generator import TrimapGenerator

# Run the models on the GPU when torch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Model components -- see the CarveKit class docstrings for the meaning of
# each option.

# Segmentation network; `generate_trimap` below uses its output as the mask.
seg_net = TracerUniversalB7(batch_size=1, device=device)

# FBA matting network, handed to the matting post-processing step.
fba = FBAMatting(batch_size=1, input_tensor_size=2048, device=device)

# Trimap generator: shared by the matting step and by `generate_trimap`.
trimap = TrimapGenerator()

# Assemble the full background-removal pipeline:
# pre-processing -> segmentation -> trimap + FBA matting post-processing.
preprocessing = PreprocessingStub()
postprocessing = MattingMethod(
    device=device,
    matting_module=fba,
    trimap_generator=trimap,
)
interface = Interface(
    seg_pipe=seg_net,
    pre_pipe=preprocessing,
    post_pipe=postprocessing,
)


def generate_trimap(original):
    """Segment *original* and return the trimap built from its mask.

    Args:
        original: PIL image from the "Generate trimap" tab, or ``None`` when
            the button is pressed without an uploaded image.

    Returns:
        The result of the module-level ``trimap`` generator applied to the
        first mask returned by ``seg_net``, or ``None`` when no image was
        supplied (which leaves the output component empty).
    """
    # An empty gr.Image input arrives as None; don't push it into the model.
    if original is None:
        return None
    # seg_net takes a batch: wrap the single image and unwrap its mask.
    masks = seg_net([original])
    return trimap(original_image=original, mask=masks[0])


def predict(image):
    """Remove the background of *image* with the CarveKit pipeline.

    Args:
        image: PIL image from the "Remove background" tab, or ``None`` when
            the button is pressed without an uploaded image.

    Returns:
        The first (only) result of ``interface([image])``, or ``None`` when
        no image was supplied (which leaves the output component empty).
    """
    # An empty gr.Image input arrives as None; don't push it into the pipeline.
    if image is None:
        return None
    # The interface processes a batch: wrap the single image, unwrap the result.
    return interface([image])[0]


footer = r"""
<center>
<img src='https://raw.githubusercontent.com/leonelhs/image-background-remove-tool/master/docs/imgs/logo.png' alt='CarveKit' width="200" height="80">
</br>
<b>
Demo based on <a href='https://github.com/OPHoperHPO/image-background-remove-tool'>CarveKit</a>
</b>
</center>
"""

with gr.Blocks(title="CarveKit") as app:

    gr.Markdown("<center><h1><b>CarveKit</b></h1></center>")
    gr.HTML("<center><h3>High-quality image background removal</h3></center>")

    with gr.Tabs() as tabs:
        with gr.TabItem("Remove background", id=0):
            with gr.Row().style(equal_height=False):
                with gr.Column():
                    input_img = gr.Image(type="pil", label="Input image")
                    run_btn = gr.Button(variant="primary")
                with gr.Column():
                    output_img = gr.Image(type="pil", label="result")

            run_btn.click(predict, [input_img], [output_img])

        with gr.TabItem("Generate trimap", id=1):
            with gr.Row().style(equal_height=False):
                with gr.Column():
                    trimap_input = gr.Image(type="pil", label="Input image")
                    trimap_btn = gr.Button(variant="primary")
                with gr.Column():
                    trimap_output = gr.Image(type="pil", label="result")

            trimap_btn.click(generate_trimap, [trimap_input], [trimap_output])

    # with gr.Row():
    #     examples_data = [[f"examples/{x:02d}.jpg"] for x in range(1, 4)]
    #     examples = gr.Dataset(components=[input_img], samples=examples_data)
    #     examples.click(lambda x: x[0], [examples], [input_img])

    with gr.Row():
        gr.HTML(footer)

# `launch(enable_queue=...)` is deprecated in Gradio 3.x in favour of
# `Blocks.queue()`; enable the request queue there, then start the server.
app.queue().launch(share=False, debug=True, show_error=True)