|
import gradio as gr |
|
import torch |
|
from carvekit.api.interface import Interface |
|
from carvekit.ml.wrap.fba_matting import FBAMatting |
|
from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7 |
|
from carvekit.pipelines.postprocessing import MattingMethod |
|
from carvekit.pipelines.preprocessing import PreprocessingStub |
|
from carvekit.trimap.generator import TrimapGenerator |
|
|
|
# Prefer the GPU whenever torch can see one; otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Segmentation stage of the pipeline (handed to Interface as ``seg_pipe``).
# One image per batch.
seg_net = TracerUniversalB7(
    batch_size=1,
    device=device,
)

# Matting model used by the post-processing stage.
# NOTE(review): input_tensor_size presumably sets the working resolution of
# FBA; 2048 can be slow on CPU — confirm against carvekit docs.
fba = FBAMatting(
    batch_size=1,
    device=device,
    input_tensor_size=2048,
)

# Trimap generator consumed by MattingMethod below (default settings).
trimap = TrimapGenerator()

# No-op preprocessing stage.
preprocessing = PreprocessingStub()

# Post-processing stage combining the matting module and the trimap
# generator; presumably refines the segmentation mask edges.
postprocessing = MattingMethod(
    device=device,
    matting_module=fba,
    trimap_generator=trimap,
)

# Callable pipeline: ``interface([img])`` returns a list of results.
interface = Interface(
    seg_pipe=seg_net,
    pre_pipe=preprocessing,
    post_pipe=postprocessing,
)
|
|
|
|
|
def predict(image):
    """Remove the background from a single image.

    Args:
        image: The image from the ``gr.Image(type="pil")`` input. This is
            ``None`` when the user clicks Run without uploading anything.

    Returns:
        The first (and only) element of ``interface([image])``. Returns
        ``None`` when no image was provided, which makes Gradio clear the
        output component instead of showing an error.
    """
    if image is None:
        # Gradio passes None for an empty image input. The CarveKit pipeline
        # is not expected to accept None (unverified), so short-circuit here.
        return None
    return interface([image])[0]
|
|
|
|
|
# HTML credits footer, rendered at the bottom of the page by gr.HTML(footer)
# in the UI definition; links to the CarveKit project.
# NOTE: the blank lines are part of the string literal; they are harmless in
# HTML but kept verbatim.
footer = r"""

<center>

<b>

Demo based on <a href='https://github.com/OPHoperHPO/image-background-remove-tool'>CarveKit</a>

</b>

</center>

"""
|
|
|
# Demo page layout:
#   - input image + Run button in the left column, result in the right;
#   - a row of clickable example images;
#   - a credits footer.
with gr.Blocks(title="CarveKit") as app:
    gr.HTML("<center><h1>Image Remove Background</h1></center>")

    with gr.Row():
        with gr.Column():
            input_img = gr.Image(type="pil", label="Input image")
            run_btn = gr.Button(variant="primary")
        with gr.Column():
            output_img = gr.Image(type="pil", label="result")

    run_btn.click(predict, [input_img], [output_img])

    with gr.Row():
        # examples/01.jpg .. examples/03.jpg, relative to the current working
        # directory; each sample is a one-element list (one input component).
        examples_data = [[f"examples/{x:02d}.jpg"] for x in range(1, 4)]
        examples = gr.Dataset(components=[input_img], samples=examples_data)
        # Forward the clicked sample's single entry into the input image.
        examples.click(lambda x: x[0], [examples], [input_img])

    with gr.Row():
        gr.HTML(footer)

# ``launch(enable_queue=...)`` is deprecated in Gradio 3.x and was removed in
# Gradio 4, where it raises TypeError. Enabling the queue through
# Blocks.queue() (which returns the Blocks) works on both.
app.queue().launch(share=False, debug=True, show_error=True)
|
|