Spaces:
fantos
/
Runtime error

nuking / app.py
arxivgpt kim
Update app.py
a119d24 verified
raw
history blame
3.88 kB
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
from huggingface_hub import hf_hub_download
import gradio as gr
from gradio_imageslider import ImageSlider
from briarmbg import BriaRMBG
import PIL
from PIL import Image
from typing import Tuple
# ๋ชจ๋ธ ์ดˆ๊ธฐํ™” ๋ฐ ๋กœ๋“œ
net = BriaRMBG()
model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')
if torch.cuda.is_available():
net.load_state_dict(torch.load(model_path))
net = net.cuda()
else:
net.load_state_dict(torch.load(model_path, map_location="cpu"))
net.eval()
def resize_image(image, model_input_size=(1024, 1024)):
    """Return an RGB copy of *image* resampled to the network input size.

    The aspect ratio is not preserved: the image is stretched to exactly
    ``model_input_size`` using bilinear resampling.

    Args:
        image: PIL image in any mode; it is converted to RGB first.
        model_input_size: Target ``(width, height)``; defaults to 1024x1024.

    Returns:
        A new RGB PIL image of size ``model_input_size``.
    """
    rgb = image.convert('RGB')
    return rgb.resize(model_input_size, Image.BILINEAR)
def _predict_mask(orig_image):
    """Run the module-level RMBG ``net`` on an RGB PIL image; return its mask.

    Args:
        orig_image: RGB PIL image.

    Returns:
        A single-channel ("L") PIL image with the same size as *orig_image*;
        high values mark the pixels to keep.
    """
    w, h = orig_image.size
    im_np = np.array(resize_image(orig_image))
    # HWC uint8 -> NCHW float in [0, 1], then shifted to [-0.5, 0.5].
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0) / 255.0
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    if torch.cuda.is_available():
        im_tensor = im_tensor.cuda()

    # Inference.
    with torch.no_grad():
        result = net(im_tensor)

    # Post-processing. Bilinear interpolate requires a 4-D (N, C, H, W) tensor;
    # with a batch of one and a single-channel map, [0, 0] leaves (h, w).
    # Explicit indexing (instead of squeeze()) keeps a 1-px-wide/tall image 2-D.
    pred = F.interpolate(result[0][0], size=(h, w), mode='bilinear', align_corners=False)[0, 0]
    # NOTE(review): BriaRMBG.forward (briarmbg.py in briaai/RMBG-1.4) already
    # applies sigmoid to its outputs. A second torch.sigmoid would squash the
    # mask into roughly [0.5, 0.73] and leave the background half-opaque, so
    # min-max stretch to [0, 1] instead, as BRIA's reference app does.
    # Guard against a constant map, which would divide by zero.
    hi, lo = pred.max(), pred.min()
    if hi > lo:
        pred = (pred - lo) / (hi - lo)
    else:
        pred = pred.clamp(0, 1)
    mask = (pred * 255).byte().cpu().numpy()
    # A 2-D uint8 array is converted to an "L" (8-bit grayscale) image.
    return Image.fromarray(mask)


def process(image, background_image=None):
    """Remove the background from *image*, optionally placing it on a new one.

    Args:
        image: Input picture as a NumPy array (fed by the
            ``gr.Image(type="numpy")`` input), or None if nothing was supplied.
        background_image: Optional PIL image. When given, the cut-out is
            centred onto it via ``merge_images``.

    Returns:
        An RGBA PIL image: the cut-out on a transparent canvas, or the cut-out
        merged onto *background_image*.

    Raises:
        gr.Error: if no input image was supplied.
    """
    if image is None:
        # Presumably Gradio passes None when the user submits without an
        # upload; fail with a readable message instead of a traceback.
        raise gr.Error("Please upload an image to remove its background.")

    orig_image = Image.fromarray(image).convert("RGB")
    mask_image = _predict_mask(orig_image)

    # Use the mask directly as the alpha channel. Compositing against a
    # transparent black canvas would also scale RGB by the mask and darken
    # soft edges.
    foreground_image = orig_image.convert("RGBA")
    foreground_image.putalpha(mask_image)

    # Optional background replacement.
    if background_image is not None:
        return merge_images(background_image, foreground_image)
    return foreground_image
def merge_images(background_image, foreground_image, scale_factor=0.3):
    """Paste a background-free cut-out onto the centre of a background image.

    The cut-out is resized, keeping its aspect ratio, so that its width is
    ``scale_factor`` of the background's width. If that would make it taller
    than the background, it is shrunk further to fit the background height, so
    it is never clipped. The cut-out's own alpha channel is used as the paste
    mask, so transparent regions let the background show through.

    Args:
        background_image: PIL image used as the canvas. It is converted to
            RGBA; the caller's image is not modified.
        foreground_image: PIL image whose alpha channel marks the foreground.
        scale_factor: Fraction of the background width the cut-out should
            span. Defaults to 0.3 (30%).

    Returns:
        A new RGBA PIL image with the size of *background_image*.
    """
    background = background_image.convert("RGBA")
    foreground = foreground_image.convert("RGBA")

    # Target width is a fraction of the background width; height follows the
    # cut-out's aspect ratio. Clamp to >= 1 px to avoid a zero-sized resize.
    fg_w = max(1, int(background.width * scale_factor))
    fg_h = max(1, int(foreground.height * fg_w / foreground.width))
    if fg_h > background.height:
        # A tall cut-out would overflow the background vertically and be
        # clipped when pasted; fit it to the background height instead.
        fg_h = background.height
        fg_w = max(1, int(foreground.width * fg_h / foreground.height))
    foreground_resized = foreground.resize((fg_w, fg_h), Image.Resampling.LANCZOS)

    # Top-left corner that centres the cut-out on the background.
    x = (background.width - fg_w) // 2
    y = (background.height - fg_h) // 2
    # Third argument: use the cut-out's own alpha channel as the paste mask.
    background.paste(foreground_resized, (x, y), foreground_resized)
    return background
# Gradio UI: an input photo plus an optional replacement background; the two
# inputs map positionally onto process(image, background_image).
title = "Background Removal"
description = "This is a demo for BRIA RMBG 1.4 using the BRIA RMBG-1.4 image matting model as backbone."

demo = gr.Interface(
    title=title,
    description=description,
    fn=process,
    inputs=[
        gr.Image(type="numpy", label="Image to remove background"),
        # Depending on the Gradio version, an explicit 'optional=True' may not
        # be needed for this input.
        gr.Image(type="pil", label="Optional: Background Image"),
    ],
    outputs="image",
)

if __name__ == "__main__":
    # Start the web server only when run as a script, not on import.
    demo.launch()