import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
from huggingface_hub import hf_hub_download
import gradio as gr
from briarmbg import BriaRMBG
from PIL import Image

# Load the RMBG-1.4 weights from the Hugging Face Hub.
net = BriaRMBG()
# model_path = "./model1.pth"
model_path = hf_hub_download("briaai/RMBG-1.4", "model.pth")
if torch.cuda.is_available():
    net.load_state_dict(torch.load(model_path))
    net = net.cuda()
else:
    net.load_state_dict(torch.load(model_path, map_location="cpu"))
net.eval()

def resize_image(image):
    # RMBG-1.4 expects a fixed 1024x1024 RGB input.
    image = image.convert('RGB')
    model_input_size = (1024, 1024)
    image = image.resize(model_input_size, Image.BILINEAR)
    return image
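
# Note: resizing straight to 1024x1024 does not preserve aspect ratio. This only
# distorts the model input; process() interpolates the matte back to the original
# size, so the final cutout still lines up with the source image.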

def process(image):
    # Convert to a PIL image only when the input arrives as a numpy array.
    if isinstance(image, np.ndarray):
        orig_image = Image.fromarray(image)
    else:
        # Already a PIL.Image.Image object; use it as-is.
        orig_image = image

    w, h = orig_image.size
    image = resize_image(orig_image)
    im_np = np.array(image)
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
    im_tensor = torch.unsqueeze(im_tensor, 0)
    im_tensor = torch.divide(im_tensor, 255.0)
    # RMBG-1.4 normalization: mean 0.5 and std 1.0 per channel.
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    if torch.cuda.is_available():
        im_tensor = im_tensor.cuda()

    # Inference
    result = net(im_tensor)

    # Post-process: resize the predicted matte back to the original resolution
    # and min-max normalize it to [0, 1].
    result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0)
    ma = torch.max(result)
    mi = torch.min(result)
    result = (result - mi) / (ma - mi)

    # Matte to PIL: scale to 8-bit grayscale.
    im_array = (result * 255).cpu().data.numpy().astype(np.uint8)
    pil_im = Image.fromarray(np.squeeze(im_array))

    # Composite the original image onto a transparent canvas, using the matte as the alpha mask.
    new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
    new_im.paste(orig_image, mask=pil_im)
    return new_im
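
# A minimal standalone usage sketch (hypothetical file names), bypassing the Gradio UI:
#   cutout = process(Image.open("photo.jpg"))
#   cutout.save("photo_cutout.png")  # RGBA; background fully transparent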

def calculate_position(org_size, add_size, position):
    # Map a named anchor to the top-left coordinate at which to paste the foreground.
    if position == "Top Left":
        return (0, 0)
    elif position == "Top Center":
        return ((org_size[0] - add_size[0]) // 2, 0)
    elif position == "Top Right":
        return (org_size[0] - add_size[0], 0)
    elif position == "Middle Left":
        return (0, (org_size[1] - add_size[1]) // 2)
    elif position == "Middle Center":
        return ((org_size[0] - add_size[0]) // 2, (org_size[1] - add_size[1]) // 2)
    elif position == "Middle Right":
        return (org_size[0] - add_size[0], (org_size[1] - add_size[1]) // 2)
    elif position == "Bottom Left":
        return (0, org_size[1] - add_size[1])
    elif position == "Bottom Center":
        return ((org_size[0] - add_size[0]) // 2, org_size[1] - add_size[1])
    elif position == "Bottom Right":
        return (org_size[0] - add_size[0], org_size[1] - add_size[1])
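
# Worked example: an 800x600 background and a 200x100 foreground anchored at
# "Bottom Center" give ((800 - 200) // 2, 600 - 100) = (300, 500).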

def merge(org_image, add_image, scale, position):
    # Scale the foreground by the given percentage, then paste it onto the
    # background in place, using the foreground's own alpha channel as the mask.
    scale_percentage = scale / 100.0
    new_size = (int(add_image.width * scale_percentage), int(add_image.height * scale_percentage))
    add_image = add_image.resize(new_size, Image.Resampling.LANCZOS)
    position = calculate_position(org_image.size, add_image.size, position)
    org_image.paste(add_image, position, add_image)
    return org_image
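
# Usage sketch (hypothetical file names): composite a cutout at half scale.
#   background = Image.open("bg.png").convert("RGBA")
#   foreground = Image.open("photo_cutout.png")
#   merged = merge(background, foreground, 50, "Bottom Right")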

with gr.Blocks() as demo:
    with gr.Tab("Background Removal"):
        with gr.Column():
            gr.Markdown("## BRIA RMBG 1.4")
            gr.HTML('''
                <p style="margin-bottom: 10px; font-size: 94%">
                    This is a demo of BRIA RMBG 1.4, which uses the
                    <a href="https://huggingface.co/briaai/RMBG-1.4" target="_blank">BRIA RMBG-1.4 image matting model</a> as its backbone.
                </p>
            ''')
            input_image = gr.Image(type="pil")
            output_image = gr.Image()
            process_button = gr.Button("Remove Background")
            process_button.click(fn=process, inputs=input_image, outputs=output_image)
    with gr.Tab("Merge"):
        with gr.Column():
            org_image = gr.Image(label="Background", type='pil', image_mode='RGBA', height="40vh")
            add_image = gr.Image(label="Foreground", type='pil', image_mode='RGBA', height="40vh")
            scale = gr.Slider(minimum=10, maximum=200, step=1, value=100, label="Scale of Foreground Image (%)")
            position = gr.Radio(
                choices=["Middle Center", "Top Left", "Top Center", "Top Right", "Middle Left",
                         "Middle Right", "Bottom Left", "Bottom Center", "Bottom Right"],
                value="Middle Center",
                label="Position of Foreground Image",
            )
            btn_merge = gr.Button("Merge Images")
            result_merge = gr.Image(height="40vh")
            btn_merge.click(fn=merge, inputs=[org_image, add_image, scale, position], outputs=result_merge)

demo.launch()
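# Hypothetical alternative when running locally: request a temporary public URL.
#   demo.launch(share=True)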