|
import numpy as np |
|
import torch |
|
import torch.nn.functional as F |
|
from torchvision.transforms.functional import normalize |
|
from huggingface_hub import hf_hub_download |
|
import gradio as gr |
|
from gradio_imageslider import ImageSlider |
|
from briarmbg import BriaRMBG |
|
import PIL |
|
from PIL import Image |
|
from typing import Tuple |
|
|
|
|
|
# Build the RMBG-1.4 network once at import time and fetch its weights from
# the Hugging Face Hub.
net = BriaRMBG()
model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')

# With CUDA the checkpoint is loaded as saved and the network is then moved to
# the GPU; without it, tensors are remapped onto the CPU while loading.
# NOTE(review): torch.load unpickles the downloaded file; consider passing
# weights_only=True if the installed torch version supports it.
_use_cuda = torch.cuda.is_available()
_load_kwargs = {} if _use_cuda else {"map_location": "cpu"}
net.load_state_dict(torch.load(model_path, **_load_kwargs))
if _use_cuda:
    net = net.cuda()

# Inference mode (affects layers such as dropout / batch-norm).
net.eval()
|
|
|
def resize_image(image, model_input_size=(1024, 1024)):
    """Return an RGB copy of *image* resized to *model_input_size*.

    *model_input_size* is a PIL ``(width, height)`` pair. Bilinear resampling
    is used, and the aspect ratio is not preserved.
    """
    rgb = image.convert('RGB')
    return rgb.resize(model_input_size, Image.BILINEAR)
|
|
|
|
|
def _as_pil(image):
    """Return *image* as a ``PIL.Image.Image``.

    PIL images are returned unchanged. Anything else (e.g. a NumPy array from a
    ``type="numpy"`` component) is wrapped with ``Image.fromarray``.
    """
    if isinstance(image, Image.Image):
        return image
    return Image.fromarray(image)


def process(image, background_image=None):
    """Remove the background of *image* using the RMBG network ``net``.

    Args:
        image: Input picture, either a ``PIL.Image.Image`` or an array
            accepted by ``PIL.Image.fromarray``.
        background_image: Optional replacement background (PIL image or
            array). It is resized to the input's size, and the cut-out is
            alpha-composited over it.

    Returns:
        An RGBA ``PIL.Image`` at the input's original size. Its alpha channel
        is the predicted foreground mask. When *background_image* is given,
        the composite over that background is returned instead.
    """
    # Use PIL input directly: round-tripping it through Image.fromarray
    # loses the palette of "P"-mode images.
    orig_image = _as_pil(image).convert("RGB")
    w, h = orig_image.size

    # Pre-process: fixed-size RGB, scaled to [0, 1], laid out as 1xCxHxW,
    # then normalised with mean 0.5 / std 1.0 per channel.
    resized_image = resize_image(orig_image)
    im_np = np.array(resized_image).astype(np.float32) / 255.0
    im_tensor = torch.tensor(im_np).permute(2, 0, 1).unsqueeze(0)
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    if torch.cuda.is_available():
        im_tensor = im_tensor.cuda()

    with torch.no_grad():
        result = net(im_tensor)

    # Upsample the first prediction map back to the input size. Indexing
    # [0, 0] (batch 0, channel 0) yields an (h, w) map even when h or w is 1;
    # a blanket squeeze() would collapse that axis too.
    pred = F.interpolate(
        result[0][0], size=(h, w), mode="bilinear", align_corners=False
    )[0, 0]

    # NOTE(review): assumes BriaRMBG already applies a sigmoid to its outputs
    # (as the upstream briarmbg.py does) -- confirm. Under that assumption, a
    # second torch.sigmoid squashes the map into ~[0.5, 0.73] and leaves the
    # background roughly half opaque. Min-max rescaling to [0, 1] follows the
    # reference RMBG-1.4 post-processing. The clamp avoids 0/0 on a constant map.
    mi, ma = pred.min(), pred.max()
    pred = (pred - mi) / (ma - mi).clamp_min(1e-8)
    mask = (pred * 255).byte().cpu().numpy()

    # putalpha converts the RGB image to RGBA, using the mask as alpha.
    orig_image.putalpha(Image.fromarray(mask, "L"))

    if background_image is None:
        return orig_image

    background = _as_pil(background_image).convert("RGBA").resize(orig_image.size)
    return Image.alpha_composite(background, orig_image)
|
|
|
|
|
|
|
def merge_images(background_image, foreground_image, scale_factor=0.3):
    """Insert a background-removed image, with transparency, into a background.

    The foreground is resized with its aspect ratio preserved so that its width
    is ``scale_factor`` times the background width (30% by default). It is then
    centred on the background and alpha-composited over it.

    Args:
        background_image: Background ``PIL.Image``; converted to RGBA.
        foreground_image: Cut-out ``PIL.Image``; converted to RGBA, and its
            alpha channel is honoured.
        scale_factor: Foreground width as a fraction of the background width.

    Returns:
        A new RGBA ``PIL.Image`` with the background's size. The input images
        are not modified.
    """
    background = background_image.convert("RGBA")
    foreground = foreground_image.convert("RGBA")

    # Nothing to insert; also avoids dividing by a zero width below.
    if foreground.width == 0 or foreground.height == 0:
        return background

    # Keep at least 1 px per side, because PIL cannot resize to a 0-sized image.
    foreground_width = max(1, int(background.width * scale_factor))
    foreground_height = max(
        1, int(foreground.height * foreground_width / foreground.width)
    )
    foreground_resized = foreground.resize(
        (foreground_width, foreground_height), Image.Resampling.LANCZOS
    )

    # Centre the foreground. Offsets may be negative when it is larger than the
    # background; paste() clips to the canvas.
    x = (background.width - foreground_width) // 2
    y = (background.height - foreground_height) // 2

    # Draw the cut-out on a transparent layer the size of the background, then
    # alpha-composite. A masked paste() would also blend the alpha band and
    # leave semi-transparent holes in an opaque background.
    layer = Image.new("RGBA", background.size, (0, 0, 0, 0))
    layer.paste(foreground_resized, (x, y))
    return Image.alpha_composite(background, layer)
|
|
|
|
|
# Display strings for the demo.
# NOTE(review): these names are not referenced by the gr.Interface below,
# which receives its own title/description strings.
title, description = (
    "Background Removal",
    "This is a demo for BRIA RMBG 1.4 using the BRIA RMBG-1.4 image matting model as backbone.",
)
|
|
|
|
|
def add_blue_background(image):
    """Flatten *image* onto a solid blue canvas of the same size.

    *image* is converted to RGBA, so its alpha channel decides how much blue
    shows through. Returns a new RGBA ``PIL.Image``.
    """
    rgba = image.convert("RGBA")
    canvas = Image.new("RGBA", image.size, "blue")
    return Image.alpha_composite(canvas, rgba)
|
|
|
|
|
# UI components. gr.inputs / gr.outputs were deprecated in Gradio 3 and removed
# in Gradio 4 (gradio_imageslider, imported by this file, targets Gradio 4).
# gr.Image serves both as the input and the output component.
inputs = gr.Image(type="pil")
outputs = [
    gr.Image(type="pil", label="Background Removed"),
    gr.Image(type="pil", label="Blue Background Added"),
]
|
|
|
def update_interface(image):
    """Gradio callback: return the cut-out and a blue-background version of it.

    Args:
        image: Uploaded picture, or ``None`` when the user submits without
            an image.

    Returns:
        ``(transparent_image, blue_background_image)``, or ``(None, None)``
        for an empty submission, so the UI shows blank outputs instead of
        raising inside ``process``.
    """
    if image is None:
        return None, None
    transparent_image = process(image)
    return transparent_image, add_blue_background(transparent_image)
|
|
|
# Single-page app: one image in, the cut-out and its blue-background
# version out.
demo = gr.Interface(
    update_interface,
    inputs,
    outputs,
    title="Background Removal and Add Blue Background",
    description="This adds a blue background to the transparent image.",
)


def _main():
    """Start the Gradio app."""
    demo.launch()


# Launch only when executed as a script, not when imported.
if __name__ == "__main__":
    _main()