import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
from huggingface_hub import hf_hub_download
import gradio as gr
from briarmbg import BriaRMBG
from PIL import Image

# Initialize the model and load the pretrained RMBG-1.4 weights
net = BriaRMBG()
model_path = hf_hub_download("briaai/RMBG-1.4", "model.pth")
if torch.cuda.is_available():
    net.load_state_dict(torch.load(model_path))
    net = net.cuda()
else:
    net.load_state_dict(torch.load(model_path, map_location="cpu"))
net.eval()


def resize_image(image, model_input_size=(1024, 1024)):
    """Resize an image to the model's expected input size."""
    image = image.convert("RGB")
    image = image.resize(model_input_size, Image.Resampling.BILINEAR)
    return image


def process(image, background_image=None):
    """Remove the background from `image`; optionally composite the cutout onto `background_image`."""
    # Prepare the input image
    orig_image = Image.fromarray(image).convert("RGB")
    w, h = orig_image.size
    image = resize_image(orig_image)
    im_np = np.array(image)
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0) / 255.0
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    if torch.cuda.is_available():
        im_tensor = im_tensor.cuda()

    # Inference
    with torch.no_grad():
        result = net(im_tensor)

    # Post-processing: resize the predicted matte back to the original size and
    # rescale it to the full [0, 1] range (min-max normalization, as in the
    # reference RMBG-1.4 post-processing). The model output is already
    # sigmoid-activated, so applying a second sigmoid would wash out the mask.
    result = torch.squeeze(
        F.interpolate(result[0][0], size=(h, w), mode="bilinear", align_corners=False), 0
    )
    result = (result - result.min()) / (result.max() - result.min())
    mask = (result * 255).byte().cpu().numpy()

    # Ensure the mask is 2-D; squeeze any leftover batch/channel dimensions
    if mask.ndim > 2:
        mask = mask.squeeze()

    # Make sure the mask is uint8 before handing it to PIL
    mask = mask.astype(np.uint8)

    # Build the final image, using the mask as the alpha channel
    orig_image = orig_image.convert("RGBA")
    final_image = Image.new("RGBA", orig_image.size, (0, 0, 0, 0))
    mask_image = Image.fromarray(mask, mode="L")
    foreground_image = Image.composite(orig_image, final_image, mask_image)

    # Optionally composite the cutout onto a new background
    if background_image is not None:
        final_image = merge_images(background_image, foreground_image)
    else:
        final_image = foreground_image

    return final_image


def merge_images(background_image, foreground_image):
    """
    Paste the background-removed image onto a background image.
    The cutout is scaled to 30% of the background width and centered on the background.
    """
    background = background_image.convert("RGBA")
    foreground = foreground_image.convert("RGBA")

    # Scale the foreground to 30% of the background width, keeping its aspect ratio
    scale_factor = 0.3
    foreground_width = int(background.width * scale_factor)
    foreground_height = int(foreground.height * foreground_width / foreground.width)
    new_size = (foreground_width, foreground_height)
    foreground_resized = foreground.resize(new_size, Image.Resampling.LANCZOS)

    # Compute coordinates that center the foreground on the background
    x = (background.width - foreground_width) // 2
    y = (background.height - foreground_height) // 2

    # Paste the foreground onto the background, using its alpha channel as the mask
    background.paste(foreground_resized, (x, y), foreground_resized)

    return background


title = "Background Removal"
description = "Background removal demo using the BRIA RMBG-1.4 image matting model."

demo = gr.Interface(
    fn=process,
    inputs=[
        gr.Image(type="numpy", label="Image to remove background"),
        gr.Image(type="pil", label="Optional: Background Image"),  # 'optional=True' may not be needed, depending on the Gradio version
    ],
    outputs="image",
    title=title,
    description=description,
)

if __name__ == "__main__":
    demo.launch()
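
# --- Usage sketch (commented out; not part of the demo) ---
# A minimal example of calling process() directly, without the Gradio UI,
# e.g. for scripted or batch use. "input.jpg" and "background.jpg" are
# hypothetical local files used only for illustration.
#
#   input_array = np.array(Image.open("input.jpg").convert("RGB"))  # numpy input, as gr.Image(type="numpy") provides
#   background = Image.open("background.jpg")
#
#   cutout = process(input_array)                   # RGBA image with a transparent background
#   composite = process(input_array, background)    # cutout pasted onto the background, centered at 30% width
#
#   cutout.save("cutout.png")                       # PNG keeps the alpha channel
#   composite.convert("RGB").save("composite.jpg")  # drop alpha before saving as JPEG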