import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
from huggingface_hub import hf_hub_download
import gradio as gr
from gradio_imageslider import ImageSlider
from briarmbg import BriaRMBG
import PIL
from PIL import Image
from typing import Tuple

# 모델 초기화 및 로드
net = BriaRMBG()
model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')
if torch.cuda.is_available():
    net.load_state_dict(torch.load(model_path))
    net = net.cuda()
else:
    net.load_state_dict(torch.load(model_path, map_location="cpu"))
net.eval()

def resize_image(image, model_input_size=(1024, 1024)):
    image = image.convert('RGB')
    image = image.resize(model_input_size, Image.BILINEAR)
    return image


def process(image, background_image=None):
    # 이미지 준비
    orig_image = Image.fromarray(image).convert("RGB")
    w, h = orig_image.size
    resized_image = resize_image(orig_image)
    im_np = np.array(resized_image).astype(np.float32) / 255.0
    im_tensor = torch.tensor(im_np).permute(2, 0, 1).unsqueeze(0)
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    if torch.cuda.is_available():
        im_tensor = im_tensor.cuda()

    # 추론
    with torch.no_grad():
        result = net(im_tensor)

    # 후처리
    result = torch.squeeze(F.interpolate(result[0][0], size=(h, w), mode='bilinear', align_corners=False), 0)
    result = torch.sigmoid(result)
    mask = (result * 255).byte().cpu().numpy()

    if mask.ndim > 2:
        mask = mask.squeeze()

    mask = mask.astype(np.uint8)

    # 마스크를 알파 채널로 사용하여 최종 이미지 생성
    final_image = Image.new("RGBA", orig_image.size)
    orig_image.putalpha(Image.fromarray(mask, 'L'))

    if background_image:
        # 배경 이미지가 제공된 경우, 배경 이미지 크기 조정
        background = background_image.convert("RGBA").resize(orig_image.size)
        # 배경과 전경(알파 적용된 원본 이미지) 합성
        final_image = Image.alpha_composite(background, orig_image)
    else:
        # 배경 이미지가 없는 경우, 투명도가 적용된 원본 이미지를 최종 이미지로 사용
        final_image = orig_image

    return final_image



def merge_images(background_image, foreground_image):
    """
    배경 이미지에 배경이 제거된 이미지를 투명하게 삽입합니다.
    배경이 제거된 이미지는 배경 이미지 중앙에 30% 크기로 축소되어 삽입됩니다.
    """
    background = background_image.convert("RGBA")
    foreground = foreground_image.convert("RGBA")
    
    # 전경 이미지를 배경 이미지의 30% 크기로 조정
    scale_factor = 0.3
    foreground_width = int(background.width * scale_factor)
    foreground_height = int(foreground.height * foreground_width / foreground.width)
    new_size = (foreground_width, foreground_height)
    foreground_resized = foreground.resize(new_size, Image.Resampling.LANCZOS)
    
    # 전경 이미지를 배경 이미지의 가운데에 위치시키기 위한 좌표 계산
    x = (background.width - foreground_width) // 2
    y = (background.height - foreground_height) // 2
    
    # 배경 이미지 위에 전경 이미지를 붙임
    background.paste(foreground_resized, (x, y), foreground_resized)
    
    return background


title = "Background Removal"
description = "This is a demo for BRIA RMBG 1.4 using the BRIA RMBG-1.4 image matting model as backbone."


def add_blue_background(image):
    # 배경 제거된 이미지에 푸른색 배경을 추가하는 함수
    blue_background = Image.new("RGBA", image.size, "blue")
    final_image = Image.alpha_composite(blue_background, image.convert("RGBA"))
    return final_image


inputs = "image"      # 이전: gr.inputs.Image(type="pil") -> 변경: "image"
outputs = ["image"]   # 이전 방식에서 변경되었을 수 있는 출력 부분

demo = gr.Interface(fn=process,
                    inputs=inputs,
                    outputs=outputs,
                    title="Your Demo Title",
                    description="A brief description of your app.")

if __name__ == "__main__":
    demo.launch()