File size: 3,435 Bytes
2012f74 699e9ad 2012f74 699e9ad 2012f74 699e9ad 2012f74 699e9ad 2012f74 699e9ad 1c39d4f 001870b 699e9ad 2012f74 001870b 2012f74 1c39d4f 699e9ad 1c39d4f 699e9ad 1c39d4f 699e9ad 1c39d4f 03bfe96 699e9ad 1c39d4f 699e9ad 1c39d4f 699e9ad 2012f74 4d03b45 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
import gradio as gr
import torch
from carvekit.api.interface import Interface
from carvekit.ml.wrap.basnet import BASNET
from carvekit.ml.wrap.deeplab_v3 import DeepLabV3
from carvekit.ml.wrap.fba_matting import FBAMatting
from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7
from carvekit.ml.wrap.u2net import U2NET
from carvekit.pipelines.postprocessing import MattingMethod
from carvekit.pipelines.preprocessing import PreprocessingStub
from carvekit.trimap.generator import TrimapGenerator
# Run every model on the GPU when torch reports CUDA support, else on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Segmentation wrappers keyed by the name offered in the UI dropdowns.
# All of them are instantiated here, at import time, in the listed order.
segment_net = {
    name: wrapper(device=device, batch_size=1)
    for name, wrapper in (
        ("U2NET", U2NET),
        ("BASNET", BASNET),
        ("DeepLabV3", DeepLabV3),
        ("TracerUniversalB7", TracerUniversalB7),
    )
}

# Shared pipeline stages handed to the carvekit Interface by predict().
fba = FBAMatting(device=device, input_tensor_size=2048, batch_size=1)
trimap = TrimapGenerator()
preprocessing = PreprocessingStub()
postprocessing = MattingMethod(
    matting_module=fba,
    trimap_generator=trimap,
    device=device,
)

# Dropdown choices: the segment_net keys, in insertion order.
method_choices = list(segment_net)
def generate_trimap(method, original):
    """Build a trimap for ``original`` from the selected segmentor's mask.

    Args:
        method: Key into ``segment_net`` naming the segmentation model to run.
        original: Input image supplied by the UI, or ``None`` when the image
            component is empty.

    Returns:
        Whatever the module-level ``trimap`` generator returns for the image
        and the first mask produced by the segmentor.

    Raises:
        ValueError: If no image was supplied.
        KeyError: If ``method`` is not a key of ``segment_net``.
    """
    # Without this guard a None image would be forwarded straight into the
    # segmentation model; fail early with a message the user can act on
    # (the app is launched with show_error=True, so it is surfaced in the UI).
    if original is None:
        raise ValueError("Please provide an input image before generating a trimap.")

    segmentor = segment_net[method]
    # The segmentor takes a batch (list) of images; we pass exactly one and
    # use the single mask that comes back.
    masks = segmentor([original])
    return trimap(original_image=original, mask=masks[0])
def predict(method, image):
    """Run the background-removal pipeline on ``image`` with the chosen segmentor.

    Args:
        method: Key into ``segment_net`` naming the segmentation model to run.
        image: Input image supplied by the UI, or ``None`` when the image
            component is empty.

    Returns:
        The first element of the list returned by the carvekit ``Interface``
        when called with a one-image batch.

    Raises:
        ValueError: If no image was supplied.
        KeyError: If ``method`` is not a key of ``segment_net``.
    """
    # Without this guard a None image would be forwarded straight into the
    # pipeline; fail early with a message the user can act on (the app is
    # launched with show_error=True, so it is surfaced in the UI).
    if image is None:
        raise ValueError("Please provide an input image before removing the background.")

    # Use a separate local instead of rebinding the ``method`` parameter to a
    # different kind of object.
    segmentor = segment_net[method]
    pipeline = Interface(
        pre_pipe=preprocessing,
        post_pipe=postprocessing,
        seg_pipe=segmentor,
    )
    # The pipeline works on batches; send one image and unwrap the one result.
    return pipeline([image])[0]
# HTML footer shown at the bottom of the page: the project logo followed by a
# link to the CarveKit repository.
# NOTE: the line break is written as <br>; "</br>" is not a valid HTML tag
# (br is a void element and has no end tag).
footer = r"""
<center>
<img src='https://raw.githubusercontent.com/leonelhs/image-background-remove-tool/master/docs/imgs/logo.png' alt='CarveKit' width="200" height="80">
<br>
<b>
Demo based on <a href='https://github.com/OPHoperHPO/image-background-remove-tool'>CarveKit</a>
</b>
</center>
"""
def _add_model_tab(title, tab_id, handler):
    """Add one tab: image input + model dropdown + button, and an image output.

    Must be called while a ``gr.Tabs`` context is open, so the created
    components attach to the page being built.

    Args:
        title: Label shown on the tab.
        tab_id: Identifier passed to ``gr.TabItem``.
        handler: Callable invoked on button click as
            ``handler(selected_model_name, input_image)``; its return value
            is displayed in the output image.
    """
    with gr.TabItem(title, id=tab_id):
        with gr.Row(equal_height=False):
            with gr.Column():
                source_img = gr.Image(type="pil", label="Input image")
                model_picker = gr.Dropdown(
                    value="TracerUniversalB7",
                    label="Segmentor model",
                    choices=method_choices,
                )
                action_btn = gr.Button(variant="primary")
            with gr.Column():
                result_img = gr.Image(type="pil", label="result")
        # Each tab wires its own dropdown; the widgets are locals of this
        # helper, so one tab can never pick up (or shadow) another tab's
        # controls.
        action_btn.click(handler, [model_picker, source_img], [result_img])


# Page layout: title, subtitle, the two tool tabs, then the footer.
with gr.Blocks(title="CarveKit") as app:
    gr.Markdown("<center><h1><b>CarveKit</b></h1></center>")
    gr.HTML("<center><h3>High-quality image background removal</h3></center>")

    with gr.Tabs():
        _add_model_tab("Remove background", 0, predict)
        _add_model_tab("Trimap generator", 1, generate_trimap)

    with gr.Row():
        gr.HTML(footer)

# Enable request queuing, then start the server. show_error=True makes
# exceptions raised by the handlers visible in the browser.
app.queue()
app.launch(share=False, debug=True, show_error=True)
|