clearbg_space / app.py
Aryan Wadhawan
Add model
96da21c
raw
history blame
3.27 kB
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
import numpy as np
import os
from u2net import U2NET
import data_transforms
import torch.nn.functional as F
from skimage import io
from torchvision.transforms.functional import normalize
# Load the U2NET weights once at import time and put the network in
# inference mode (eval() changes layers such as batch-norm/dropout to their
# evaluation behavior).
# The checkpoint path is resolved relative to this script's directory instead
# of the process working directory, so launching the app from elsewhere
# (e.g. `python some/dir/app.py`) still finds "u2net.pth".
model = U2NET(3, 1)
model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "u2net.pth")
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model.eval()
# Preprocess the image
def preprocess(image):
    """Run an image array through the RescaleT(320) + ToTensorLab(flag=0) chain.

    The data transforms operate on a sample dict holding an image and a
    label, so an all-zero placeholder label is created to go with the image.
    A 2-D (grayscale) image is given a trailing channel axis first.

    Args:
        image: numpy array, presumably (H, W) or (H, W, C) — other ranks are
            passed through unchanged with a label of shape image.shape[:2].

    Returns:
        Whatever the composed ``data_transforms`` chain returns for the sample
        dict (keys "imidx", "image", "label"); its exact contents depend on
        ``data_transforms``, which is not visible here.
    """
    rank = len(image.shape)
    if rank == 2:
        # Grayscale input: add an explicit single-channel axis.
        image = image[:, :, np.newaxis]

    # For 2-D/3-D images the placeholder label is (H, W, 1); any other rank
    # keeps a label shaped like the image's first two dimensions.
    if rank in (2, 3):
        label_shape = image.shape[:2] + (1,)
    else:
        label_shape = image.shape[:2]

    sample = {
        "imidx": np.array([0]),
        "image": image,
        "label": np.zeros(label_shape),
    }
    pipeline = transforms.Compose([
        data_transforms.RescaleT(320),
        data_transforms.ToTensorLab(flag=0),
    ])
    return pipeline(sample)
# Generate the mask
def _primary_output(outputs):
    """Return the first prediction map from a model output as a 4-D tensor.

    Repeatedly takes element 0 while the value is a list/tuple, so both a flat
    tuple of side outputs and a nested list of side outputs resolve to their
    first tensor. Leading singleton axes are added until the tensor is 4-D,
    which is what ``F.interpolate(mode="bilinear")`` requires.
    """
    pred = outputs
    while isinstance(pred, (list, tuple)):
        pred = pred[0]
    while pred.dim() < 4:
        pred = pred.unsqueeze(0)
    return pred


def generate_mask(image):
    """Predict a grayscale foreground mask for a PIL image using ``model``.

    The image is converted to RGB, resized to 1024x1024, scaled to [0, 1] and
    shifted by -0.5 per channel (mean 0.5, std 1.0) before inference. The
    model's first prediction map is resized back to the input resolution and
    min-max normalized to [0, 255].

    Args:
        image: a PIL image (any mode accepted by ``Image.convert("RGB")``).

    Returns:
        ``uint8`` numpy array of shape (height, width) with values in [0, 255].
        A constant prediction (zero min-max span) yields an all-zero mask.
    """
    rgb = np.array(image.convert("RGB"))
    im_shp = rgb.shape[0:2]
    input_size = [1024, 1024]

    # HWC uint8 -> 1xCxHxW float, resized to the network input size. The cast
    # to uint8 truncates the interpolated values, as in the original pipeline.
    im_tensor = torch.tensor(rgb, dtype=torch.float32).permute(2, 0, 1)
    im_tensor = F.interpolate(torch.unsqueeze(im_tensor, 0), input_size, mode="bilinear").type(torch.uint8)
    inputs = normalize(torch.divide(im_tensor, 255.0), [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    # NOTE(review): this is x/255 - 0.5 at 1024x1024, whereas ``preprocess``
    # builds a 320x320 sample through data_transforms. Confirm which input
    # pipeline the u2net.pth checkpoint was trained with.

    with torch.no_grad():
        outputs = model(inputs)

    # (1, 1, h, w) -> (1, H, W) at the original image resolution.
    pred = _primary_output(outputs)
    pred = torch.squeeze(F.interpolate(pred, im_shp, mode="bilinear"), 0)
    pred = pred.numpy()[0]  # (H, W)

    # Single min-max normalization; guard the degenerate constant map, which
    # previously produced 0/0 = NaN before the uint8 cast.
    lo = pred.min()
    span = pred.max() - lo
    if span > 0:
        pred = (pred - lo) / span
    else:
        pred = np.zeros_like(pred)
    return (pred * 255).astype(np.uint8)
# Define the final predict method to overlay the mask
def predict(image):
    """Cut the foreground out of ``image`` and return it with transparency.

    Args:
        image: a PIL image, or ``None`` (Gradio may call the function with no
            image when nothing was uploaded).

    Returns:
        An RGBA PIL image the same size as ``image``: each pixel of the input
        is pasted through the predicted mask onto a fully transparent canvas
        (mask 255 keeps the pixel, 0 leaves it transparent, values in between
        blend). Returns ``None`` when ``image`` is ``None``.
    """
    if image is None:
        return None

    mask = generate_mask(image)

    # RGBA so the result can carry transparency.
    rgba = image.convert("RGBA")
    # The mask is a soft 0-255 grayscale map (mode "L"), used as the paste
    # mask; resizing guarantees it matches the image size.
    alpha = Image.fromarray(mask).resize(rgba.size).convert("L")

    # Image.new("RGBA", size) starts fully transparent (all channels 0).
    result = Image.new("RGBA", rgba.size)
    result.paste(rgba, mask=alpha)
    return result
# Gradio UI: a single PIL image in, a single PIL image out. The output
# component uses image_mode="RGBA" so the transparency produced by predict()
# is preserved.
input_image = gr.Image(type="pil")
output_image = gr.Image(type="pil", label="Edited Image", image_mode="RGBA")

iface = gr.Interface(
    fn=predict,
    inputs=input_image,
    outputs=output_image,
    title="Background Removal with U2NET",
    description="Upload an image and remove the background",
)

if __name__ == "__main__":
    # Start the web server only when executed as a script, not on import.
    iface.launch()