# carvekit / app.py
# Author: leonelhs
# Commit 1c39d4f: "add trimap generator"
# Size: 2.93 kB
import gradio as gr
import torch
from carvekit.api.interface import Interface
from carvekit.ml.wrap.fba_matting import FBAMatting
from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7
from carvekit.pipelines.postprocessing import MattingMethod
from carvekit.pipelines.preprocessing import PreprocessingStub
from carvekit.trimap.generator import TrimapGenerator
# Run the models on the GPU when PyTorch reports one is available, otherwise
# fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Every model object below is built once, at import time, and reused by the
# request handlers further down. The CarveKit class docstrings describe each
# option in detail.

# Segmentation network: called directly by the trimap handler and also used
# as the segmentation stage of `interface`.
seg_net = TracerUniversalB7(
    device=device,
    batch_size=1,
)

# Matting network, handed to the MattingMethod post-processing stage below.
fba = FBAMatting(
    device=device,
    input_tensor_size=2048,
    batch_size=1,
)

# Trimap generator, shared by `postprocessing` and the trimap tab handler.
trimap = TrimapGenerator()

# NOTE(review): PreprocessingStub presumably passes images through
# unchanged. TODO confirm against the CarveKit source.
preprocessing = PreprocessingStub()
postprocessing = MattingMethod(
    matting_module=fba,
    trimap_generator=trimap,
    device=device,
)

# Full background-removal pipeline that combines the pre-processing,
# segmentation and post-processing stages.
interface = Interface(
    pre_pipe=preprocessing,
    post_pipe=postprocessing,
    seg_pipe=seg_net,
)
def generate_trimap(original):
    """Build a trimap for *original* from the segmentation network's mask.

    The image is segmented with the module-level ``seg_net``. The resulting
    mask is then passed, together with the original image, to the module-level
    ``trimap`` generator.

    Args:
        original: Image from the "Generate trimap" input component, or
            ``None`` when the button is clicked before an image is uploaded.

    Returns:
        The image produced by ``trimap``, or ``None`` when no image was
        supplied. Returning ``None`` leaves the output component empty.
    """
    # Gradio hands the handler None for an empty image input. Stop here
    # instead of sending None into the models.
    if original is None:
        return None
    # seg_net is called with a batch (list) of images; take the mask for the
    # single image we sent.
    masks = seg_net([original])
    return trimap(original_image=original, mask=masks[0])
def predict(image):
    """Run the full CarveKit background-removal pipeline on *image*.

    Args:
        image: Image from the "Remove background" input component, or
            ``None`` when the button is clicked before an image is uploaded.

    Returns:
        The first (and only) result of ``interface([image])``, or ``None``
        when no image was supplied. Returning ``None`` leaves the output
        component empty.
    """
    # Gradio hands the handler None for an empty image input. Stop here
    # instead of sending None into the pipeline.
    if image is None:
        return None
    # The interface works on a list of images; unwrap the single result.
    return interface([image])[0]
# HTML rendered in the bottom row of the page: the CarveKit logo followed by
# a link to the upstream CarveKit repository.
# `<br>` is a void element. The previous `</br>` was an invalid end tag that
# browsers only accept through an HTML5 parse-error recovery rule.
footer = r"""
<center>
<img src='https://raw.githubusercontent.com/leonelhs/image-background-remove-tool/master/docs/imgs/logo.png' alt='CarveKit' width="200" height="80">
<br>
<b>
Demo based on <a href='https://github.com/OPHoperHPO/image-background-remove-tool'>CarveKit</a>
</b>
</center>
"""
def _image_tab(title, tab_id, handler):
    """Lay out one image-in / image-out tab and wire its button.

    Must be called inside an open ``gr.Tabs()`` context. The left column holds
    the input image and a button; the right column holds the result image.
    Clicking the button calls *handler* with the input image and shows the
    returned image in the result component.

    Returns:
        ``(input_image, button, output_image)``, so the caller can keep
        module-level references to the created components.
    """
    with gr.TabItem(title, id=tab_id):
        # NOTE(review): `.style(...)` here and `enable_queue` in launch() are
        # gradio 3.x APIs that gradio 4 removed. Confirm the pinned gradio
        # version before upgrading.
        with gr.Row().style(equal_height=False):
            with gr.Column():
                source = gr.Image(type="pil", label="Input image")
                button = gr.Button(variant="primary")
            with gr.Column():
                result = gr.Image(type="pil", label="result")
        button.click(handler, [source], [result])
    return source, button, result


with gr.Blocks(title="CarveKit") as app:
    gr.Markdown("<center><h1><b>CarveKit</b></h1></center>")
    gr.HTML("<center><h3>High-quality image background removal</h3></center>")
    with gr.Tabs() as tabs:
        input_img, run_btn, output_img = _image_tab(
            "Remove background", 0, predict
        )
        trimap_input, trimap_btn, trimap_output = _image_tab(
            "Generate trimap", 1, generate_trimap
        )
    # TODO: an example gallery (examples/01.jpg .. examples/03.jpg loaded into
    # input_img through a gr.Dataset) was sketched here and left disabled.
    # Re-add it if the example files ship with the app.
    with gr.Row():
        gr.HTML(footer)

app.launch(share=False, debug=True, enable_queue=True, show_error=True)