# demo source from: https://github.com/MarcoForte/FBA_Matting
import cv2
import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from networks.models import build_model
from networks.transforms import trimap_transform, normalise_image

REPO_ID = "leonelhs/FBA-Matting"

# Download the pretrained FBA weights from the Hugging Face Hub and build the model once at startup.
weights = hf_hub_download(repo_id=REPO_ID, filename="FBA.pth")
model = build_model(weights)
model.eval().cpu()

def np_to_torch(x, permute=True):
    ''' Convert a numpy array to a batched float tensor on the CPU. '''
    if permute:
        return torch.from_numpy(x).permute(2, 0, 1)[None, :, :, :].float().cpu()
    else:
        return torch.from_numpy(x)[None, :, :, :].float().cpu()
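
# A small shape sketch (added note, not part of the original demo): with
# permute=True an HWC image becomes the NCHW batch PyTorch expects, e.g.
# (480, 640, 3) -> (1, 3, 480, 640); with permute=False only a leading batch
# dimension is prepended and the axes are otherwise kept as given.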

def scale_input(x: np.ndarray, scale: float, scale_type) -> np.ndarray:
    ''' Scales inputs to a multiple of 8. '''
    h, w = x.shape[:2]
    h1 = int(np.ceil(scale * h / 8) * 8)
    w1 = int(np.ceil(scale * w / 8) * 8)
    x_scale = cv2.resize(x, (w1, h1), interpolation=scale_type)
    return x_scale
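
# Worked example (added note): with scale=1.0 an input of height 599 is
# resized to int(np.ceil(599 / 8) * 8) = 600; rounding up to a multiple of 8
# presumably matches the stride of the network's encoder.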

def inference(image_np: np.ndarray, trimap_np: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    ''' Predict alpha, foreground and background.
    Parameters:
    image_np -- the image in rgb format between 0 and 1. Dimensions: (h, w, 3)
    trimap_np -- two channel trimap, first background then foreground. Dimensions: (h, w, 2)
    Returns:
    fg: foreground image in rgb format between 0 and 1. Dimensions: (h, w, 3)
    bg: background image in rgb format between 0 and 1. Dimensions: (h, w, 3)
    alpha: alpha matte image between 0 and 1. Dimensions: (h, w)
    '''
    h, w = trimap_np.shape[:2]
    image_scale_np = scale_input(image_np, 1.0, cv2.INTER_LANCZOS4)
    trimap_scale_np = scale_input(trimap_np, 1.0, cv2.INTER_LANCZOS4)
    with torch.no_grad():
        image_torch = np_to_torch(image_scale_np)
        trimap_torch = np_to_torch(trimap_scale_np)
        trimap_transformed_torch = np_to_torch(
            trimap_transform(trimap_scale_np), permute=False)
        image_transformed_torch = normalise_image(image_torch.clone())
        output = model(
            image_torch,
            trimap_torch,
            image_transformed_torch,
            trimap_transformed_torch)
        output = cv2.resize(
            output[0].cpu().numpy().transpose((1, 2, 0)),
            (w, h), interpolation=cv2.INTER_LANCZOS4)
    alpha = output[:, :, 0]
    fg = output[:, :, 1:4]
    bg = output[:, :, 4:7]
    # Enforce the trimap constraints: known background stays fully
    # transparent, known foreground fully opaque.
    alpha[trimap_np[:, :, 0] == 1] = 0
    alpha[trimap_np[:, :, 1] == 1] = 1
    fg[alpha == 1] = image_np[alpha == 1]
    bg[alpha == 0] = image_np[alpha == 0]
    return fg, bg, alpha
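
# Example helper (a sketch added for illustration, not used by the demo):
# the three outputs combine with the standard matting equation
# C = alpha * F + (1 - alpha) * B, so compositing onto a new background is:
def composite_over(fg: np.ndarray, alpha: np.ndarray, new_bg: np.ndarray) -> np.ndarray:
    ''' fg and new_bg are (h, w, 3) RGB in [0, 1]; alpha is (h, w) in [0, 1]. '''
    return alpha[:, :, None] * fg + (1 - alpha[:, :, None]) * new_bg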

def read_image(name):
    # cv2 loads BGR; flip to RGB and normalise to [0, 1].
    return (cv2.imread(name) / 255.0)[:, :, ::-1]

def read_trimap(name):
    trimap_im = cv2.imread(name, 0) / 255.0
    h, w = trimap_im.shape
    trimap_np = np.zeros((h, w, 2))
    # Channel 0 marks definite background, channel 1 definite foreground;
    # grey (unknown) pixels stay zero in both channels.
    trimap_np[trimap_im == 1, 1] = 1
    trimap_np[trimap_im == 0, 0] = 1
    return trimap_np
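
# Example helper (a sketch added for illustration, not used by the demo):
# if only a binary mask is available, a trimap in the two-channel format
# above can be approximated by eroding the mask for the sure foreground and
# dilating it for the sure background; kernel_size sets the width of the
# unknown band.
def trimap_from_mask(mask: np.ndarray, kernel_size: int = 15) -> np.ndarray:
    ''' mask is a binary (h, w) array; returns an (h, w, 2) trimap. '''
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    sure_fg = cv2.erode(mask.astype(np.uint8), kernel)
    sure_bg = 1 - cv2.dilate(mask.astype(np.uint8), kernel)
    trimap_np = np.zeros((*mask.shape, 2))
    trimap_np[sure_bg == 1, 0] = 1
    trimap_np[sure_fg == 1, 1] = 1
    return trimap_np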

def predict(image, trimap):
    image_np = read_image(image)
    trimap_np = read_trimap(trimap)
    return inference(image_np, trimap_np)
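
# Example usage (a sketch added for illustration, not used by the app):
# predictions are RGB floats in [0, 1], so flip channels back to BGR and
# rescale before writing with cv2.imwrite.
def save_results(fg: np.ndarray, bg: np.ndarray, alpha: np.ndarray, prefix: str = "out") -> None:
    cv2.imwrite(f"{prefix}_fg.png", (fg[:, :, ::-1] * 255).astype(np.uint8))
    cv2.imwrite(f"{prefix}_bg.png", (bg[:, :, ::-1] * 255).astype(np.uint8))
    cv2.imwrite(f"{prefix}_alpha.png", (alpha * 255).astype(np.uint8))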
footer = r"""
<center>
<b>
Demo for <a href='https://github.com/MarcoForte/FBA_Matting'>FBA Matting</a>
</b>
</center>
"""
with gr.Blocks(title="FBA Matting") as app:
    gr.HTML("<center><h1>FBA Matting</h1></center>")
    gr.HTML("<center><h3>Foreground, Background, Alpha Matting Generator.</h3></center>")
    with gr.Row().style(equal_height=False):
        with gr.Column():
            input_img = gr.Image(type="filepath", label="Input image")
            input_trimap = gr.Image(type="filepath", label="Trimap image")
            run_btn = gr.Button(variant="primary")
        with gr.Column():
            fg = gr.Image(type="numpy", label="Foreground")
            bg = gr.Image(type="numpy", label="Background")
            alpha = gr.Image(type="numpy", label="Alpha")
    run_btn.click(predict, [input_img, input_trimap], [fg, bg, alpha])
    with gr.Row():
        blobs = [[
            f"./images/{x:02d}.png",
            f"./trimaps/{x:02d}.png"] for x in range(1, 7)]
        examples = gr.Dataset(components=[input_img, input_trimap], samples=blobs)
        examples.click(lambda x: (x[0], x[1]), [examples], [input_img, input_trimap])
    with gr.Row():
        gr.HTML(footer)

app.launch(share=False, debug=True, enable_queue=True, show_error=True)