Spaces:
fantos
/
Runtime error

nuking / app.py
arxivgpt kim
Update app.py
d508478 verified
raw
history blame
4.67 kB
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
from huggingface_hub import hf_hub_download
import gradio as gr
from gradio_imageslider import ImageSlider
from briarmbg import BriaRMBG
import PIL
from PIL import Image
from typing import Tuple
# 모델 초기화 및 로드
net = BriaRMBG()
model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')
if torch.cuda.is_available():
net.load_state_dict(torch.load(model_path))
net = net.cuda()
else:
net.load_state_dict(torch.load(model_path, map_location="cpu"))
net.eval()
def resize_image(image, model_input_size=(1024, 1024)):
image = image.convert('RGB')
image = image.resize(model_input_size, Image.BILINEAR)
return image
def process(image, background_image=None):
# 이미지 준비
orig_image = Image.fromarray(image).convert("RGB")
w, h = orig_image.size
resized_image = resize_image(orig_image)
im_np = np.array(resized_image).astype(np.float32) / 255.0
im_tensor = torch.tensor(im_np).permute(2, 0, 1).unsqueeze(0)
im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
if torch.cuda.is_available():
im_tensor = im_tensor.cuda()
# 추론
with torch.no_grad():
result = net(im_tensor)
# 후처리
result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear', align_corners=False), 0)
result = torch.sigmoid(result)
mask = (result * 255).byte().cpu().numpy()
if mask.ndim > 2:
mask = mask.squeeze()
mask = mask.astype(np.uint8)
# 마스크를 알파 채널로 사용하여 최종 이미지 생성
final_image = Image.new("RGBA", orig_image.size)
orig_image.putalpha(Image.fromarray(mask, 'L'))
if background_image:
# 배경 이미지가 제공된 경우, 배경 이미지 크기 조정
background = background_image.convert("RGBA").resize(orig_image.size)
# 배경과 전경(알파 적용된 원본 이미지) 합성
final_image = Image.alpha_composite(background, orig_image)
else:
# 배경 이미지가 없는 경우, 투명도가 적용된 원본 이미지를 최종 이미지로 사용
final_image = orig_image
return final_image
def merge_images(background_image, foreground_image):
"""
배경 이미지에 배경이 제거된 이미지를 투명하게 삽입합니다.
배경이 제거된 이미지는 배경 이미지 중앙에 30% 크기로 축소되어 삽입됩니다.
"""
background = background_image.convert("RGBA")
foreground = foreground_image.convert("RGBA")
# 전경 이미지를 배경 이미지의 30% 크기로 조정
scale_factor = 0.3
foreground_width = int(background.width * scale_factor)
foreground_height = int(foreground.height * foreground_width / foreground.width)
new_size = (foreground_width, foreground_height)
foreground_resized = foreground.resize(new_size, Image.Resampling.LANCZOS)
# 전경 이미지를 배경 이미지의 가운데에 위치시키기 위한 좌표 계산
x = (background.width - foreground_width) // 2
y = (background.height - foreground_height) // 2
# 배경 이미지 위에 전경 이미지를 붙임
background.paste(foreground_resized, (x, y), foreground_resized)
return background
title = "Background Removal"
description = "This is a demo for BRIA RMBG 1.4 using the BRIA RMBG-1.4 image matting model as backbone."
def add_blue_background(image):
# 배경 제거된 이미지에 푸른색 배경을 추가하는 함수
blue_background = Image.new("RGBA", image.size, "blue")
final_image = Image.alpha_composite(blue_background, image.convert("RGBA"))
return final_image
# Gradio 인터페이스에서 이미지 입력과 두 개의 출력을 설정합니다.
inputs = gr.inputs.Image(type="pil")
outputs = [gr.outputs.Image(type="pil", label="Background Removed"),
gr.outputs.Image(type="pil", label="Blue Background Added")]
def update_interface(image):
transparent_image = process(image)
blue_background_image = add_blue_background(transparent_image)
return transparent_image, blue_background_image
demo = gr.Interface(fn=update_interface,
inputs=inputs,
outputs=outputs,
title="Background Removal and Add Blue Background",
description="This adds a blue background to the transparent image.")
# 'add' 버튼을 제거했으나, 백그라운드 추가는 자동으로 진행됩니다.
if __name__ == "__main__":
demo.launch()