File size: 4,236 Bytes
dd67556 a119d24 dd67556 a119d24 dd67556 a119d24 dd67556 fc9e8ce dd67556 fc9e8ce dd67556 fc9e8ce dd67556 fc9e8ce dd67556 52b892c 5801493 e84892b 768a2fd 52b892c fc9e8ce dd67556 fc9e8ce dc146a8 e84892b fc9e8ce dc146a8 52b892c a119d24 dc146a8 e84892b 768a2fd a6a92bb a119d24 a6a92bb 768a2fd a119d24 768a2fd dd67556 e84892b ca70732 a119d24 52b892c a6a92bb e84892b 597ebe6 e84892b d508478 39590b8 d508478 39590b8 d508478 39590b8 d508478 e84892b 39590b8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
from huggingface_hub import hf_hub_download
import gradio as gr
from gradio_imageslider import ImageSlider
from briarmbg import BriaRMBG
import PIL
from PIL import Image
from typing import Tuple
# 모델 초기화 및 로드
net = BriaRMBG()
model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')
if torch.cuda.is_available():
net.load_state_dict(torch.load(model_path))
net = net.cuda()
else:
net.load_state_dict(torch.load(model_path, map_location="cpu"))
net.eval()
def resize_image(image, model_input_size=(1024, 1024)):
image = image.convert('RGB')
image = image.resize(model_input_size, Image.BILINEAR)
return image
def process(image, background_image=None):
# 이미지 준비
orig_image = Image.fromarray(image).convert("RGB")
w, h = orig_image.size
resized_image = resize_image(orig_image)
im_np = np.array(resized_image).astype(np.float32) / 255.0
im_tensor = torch.tensor(im_np).permute(2, 0, 1).unsqueeze(0)
im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
if torch.cuda.is_available():
im_tensor = im_tensor.cuda()
# 추론
with torch.no_grad():
result = net(im_tensor)
# 후처리
result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear', align_corners=False), 0)
result = torch.sigmoid(result)
mask = (result * 255).byte().cpu().numpy()
if mask.ndim > 2:
mask = mask.squeeze()
mask = mask.astype(np.uint8)
# 마스크를 알파 채널로 사용하여 최종 이미지 생성
final_image = Image.new("RGBA", orig_image.size)
orig_image.putalpha(Image.fromarray(mask, 'L'))
if background_image:
# 배경 이미지가 제공된 경우, 배경 이미지 크기 조정
background = background_image.convert("RGBA").resize(orig_image.size)
# 배경과 전경(알파 적용된 원본 이미지) 합성
final_image = Image.alpha_composite(background, orig_image)
else:
# 배경 이미지가 없는 경우, 투명도가 적용된 원본 이미지를 최종 이미지로 사용
final_image = orig_image
return final_image
def merge_images(background_image, foreground_image):
"""
배경 이미지에 배경이 제거된 이미지를 투명하게 삽입합니다.
배경이 제거된 이미지는 배경 이미지 중앙에 30% 크기로 축소되어 삽입됩니다.
"""
background = background_image.convert("RGBA")
foreground = foreground_image.convert("RGBA")
# 전경 이미지를 배경 이미지의 30% 크기로 조정
scale_factor = 0.3
foreground_width = int(background.width * scale_factor)
foreground_height = int(foreground.height * foreground_width / foreground.width)
new_size = (foreground_width, foreground_height)
foreground_resized = foreground.resize(new_size, Image.Resampling.LANCZOS)
# 전경 이미지를 배경 이미지의 가운데에 위치시키기 위한 좌표 계산
x = (background.width - foreground_width) // 2
y = (background.height - foreground_height) // 2
# 배경 이미지 위에 전경 이미지를 붙임
background.paste(foreground_resized, (x, y), foreground_resized)
return background
title = "Background Removal"
description = "This is a demo for BRIA RMBG 1.4 using the BRIA RMBG-1.4 image matting model as backbone."
def add_blue_background(image):
# 배경 제거된 이미지에 푸른색 배경을 추가하는 함수
blue_background = Image.new("RGBA", image.size, "blue")
final_image = Image.alpha_composite(blue_background, image.convert("RGBA"))
return final_image
inputs = "image" # 이전: gr.inputs.Image(type="pil") -> 변경: "image"
outputs = ["image"] # 이전 방식에서 변경되었을 수 있는 출력 부분
demo = gr.Interface(fn=process,
inputs=inputs,
outputs=outputs,
title="Your Demo Title",
description="A brief description of your app.")
if __name__ == "__main__":
demo.launch()
|