File size: 2,311 Bytes
dd67556 3ccd256 dd67556 c384edc dd67556 c384edc dd67556 c384edc cffeaa2 52b892c cffeaa2 c384edc cffeaa2 c384edc cffeaa2 c384edc 815368c cffeaa2 c384edc cffeaa2 c384edc cffeaa2 c384edc cffeaa2 c384edc cffeaa2 c384edc 3ccd256 815368c 3ccd256 815368c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
import functools

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision.transforms.functional import normalize

from briarmbg import BriaRMBG
def resize_image(image):
    """Return an RGB copy of *image* stretched to the 1024x1024 model input.

    The aspect ratio is not preserved; bilinear resampling is used.
    """
    target_size = (1024, 1024)
    return image.convert('RGB').resize(target_size, Image.BILINEAR)
@functools.lru_cache(maxsize=1)
def _load_net():
    """Build BriaRMBG with the RMBG-1.4 weights once and reuse it afterwards.

    Returns:
        A ``(net, device)`` tuple: the model in eval mode on ``device``, and the
        ``torch.device`` it lives on (CUDA when available, else CPU).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = BriaRMBG()
    model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')
    # NOTE(review): torch.load unpickles the checkpoint. It comes from a fixed
    # Hub repo rather than user input, but consider weights_only=True if the
    # installed torch version supports it.
    net.load_state_dict(torch.load(model_path, map_location=device))
    net.to(device)
    net.eval()
    return net, device


def _to_model_input(orig_image, device):
    """Convert a PIL image into a normalized (1, 3, 1024, 1024) float tensor on *device*."""
    # resize_image converts to RGB, so the array is (H, W, 3) uint8.
    im_np = np.array(resize_image(orig_image))
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
    im_tensor = torch.divide(torch.unsqueeze(im_tensor, 0), 255.0)
    # Subtract 0.5 per channel (std 1.0): maps [0, 1] to [-0.5, 0.5].
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    return im_tensor.to(device)


def _mask_from_prediction(pred, size):
    """Turn a raw prediction map into an 8-bit ``L`` mask of PIL *size* ``(w, h)``."""
    w, h = size
    # NOTE(review): assumes pred is (1, 1, H, W); indexing [0, 0] (rather than
    # squeezing) keeps a 1-pixel height or width from being collapsed.
    mask = F.interpolate(pred, size=(h, w), mode='bilinear', align_corners=False)[0, 0]
    mi = torch.min(mask)
    span = torch.max(mask) - mi
    # Min-max stretch to [0, 1]. A constant map would divide by zero and yield
    # NaNs, so fall back to a fully transparent mask instead.
    if span > 0:
        mask = (mask - mi) / span
    else:
        mask = torch.zeros_like(mask)
    im_array = (mask * 255).cpu().numpy().astype(np.uint8)
    return Image.fromarray(im_array)


def process(image):
    """Remove the background from *image* and return an RGBA PIL image.

    Args:
        image: A PIL image (the Gradio input uses ``type="pil"``) or a numpy
            array, which is converted with ``Image.fromarray``.

    Returns:
        A new RGBA image the size of the input, whose alpha channel is the
        predicted foreground mask.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    # Only numpy arrays need converting; PIL images are used as-is.
    orig_image = Image.fromarray(image) if isinstance(image, np.ndarray) else image

    net, device = _load_net()
    im_tensor = _to_model_input(orig_image, device)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        result = net(im_tensor)

    # NOTE(review): result[0][0] is presumably the primary mask output of
    # BriaRMBG; confirm against briarmbg.py.
    pil_mask = _mask_from_prediction(result[0][0], orig_image.size)

    # Paste the original onto a transparent canvas, using the mask as alpha.
    new_im = Image.new("RGBA", pil_mask.size, (0, 0, 0, 0))
    new_im.paste(orig_image, mask=pil_mask)
    return new_im
# Custom CSS for the Blocks page: hides the page footer.
css = """
footer {
visibility: hidden;
}
"""

# Single-column UI: upload an image, click the button, see the cut-out result.
with gr.Blocks(css=css) as demo:
    with gr.Column():
        input_image = gr.Image(type="pil")
        output_image = gr.Image()
        process_button = gr.Button("Remove Background Image")
    process_button.click(fn=process, inputs=input_image, outputs=output_image)

# Start the server only when run as a script, so importing this module
# (e.g. to reuse `demo` or `process`) has no server-launching side effect.
if __name__ == "__main__":
    demo.launch()
|