Spaces:
Running
Running
File size: 3,068 Bytes
542c815 3f8e328 542c815 d6e753e 8a357d1 487aa58 3267028 018621a 3267028 b98efed 542c815 988f91c 542c815 04ba376 487aa58 04ba376 487aa58 04ba376 487aa58 542c815 487aa58 04ba376 70974c3 04ba376 542c815 04ba376 d909bca c530952 6ca28a8 4941fcb 6ca28a8 d909bca 487aa58 5682b91 487aa58 d909bca |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
import gradio as gr
from briarmbg import BriaRMBG
import PIL
from PIL import Image
from typing import Tuple
import requests
from io import BytesIO
# Load the pretrained "briaai/RMBG-1.4" weights into the BriaRMBG network and
# move it to the GPU when one is available (CPU otherwise).
net = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net.to(device)
# This app only runs inference: make sure layers such as batch-norm/dropout (if
# the network has any) use inference behaviour. Harmless if from_pretrained
# already switched the model to eval mode.
net.eval()
def resize_image(image, size=(1024, 1024)):
    """Convert a PIL image to RGB and resize it to the model input size.

    Args:
        image: A PIL image of any mode.
        size: Target ``(width, height)``. Defaults to the fixed 1024x1024 input
            this app feeds to the network.

    Returns:
        A new RGB PIL image of exactly ``size``. The image is stretched with
        bilinear resampling; aspect ratio is not preserved (the predicted mask
        is resized back to the original dimensions afterwards).
    """
    rgb_image = image.convert('RGB')
    return rgb_image.resize(tuple(size), Image.BILINEAR)
def get_url_image(url, timeout=30):
    """Download ``url`` and return the response body as an in-memory stream.

    Args:
        url: http(s) URL of the image to fetch.
        timeout: Seconds passed to ``requests.get`` so a slow or unresponsive
            host cannot block the worker indefinitely.

    Returns:
        A ``BytesIO`` wrapping the raw response bytes, ready for ``Image.open``.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status.
        requests.RequestException: on connection errors, timeouts or invalid URLs.
    """
    headers = {'User-Agent': 'gradio-app'}
    # NOTE(review): ``url`` comes straight from the public UI; nothing here
    # restricts which hosts are contacted — consider an allow-list if this runs
    # inside a private network.
    response = requests.get(url, headers=headers, timeout=timeout)
    # Fail with a clear HTTP error instead of handing an error page to PIL.
    response.raise_for_status()
    return BytesIO(response.content)
def load_image(image_source):
    """Return a PIL image from a URL string, an array, or a PIL image.

    Args:
        image_source: A URL string (downloaded via ``get_url_image``;
            surrounding whitespace is ignored), an array accepted by
            ``Image.fromarray`` such as the numpy array produced by
            ``gr.Image(type="numpy")``, or a ``PIL.Image.Image``, which is
            returned unchanged.

    Returns:
        A ``PIL.Image.Image``.

    Raises:
        ValueError: if ``image_source`` is None or a blank string.
    """
    if image_source is None:
        raise ValueError("No image source provided")
    if isinstance(image_source, Image.Image):
        return image_source
    if isinstance(image_source, str):  # Check if input is a URL
        url = image_source.strip()
        if not url:
            raise ValueError("Empty image URL")
        print(f"Loading image from URL: {url}")
        return Image.open(get_url_image(url))
    print("Loading image from file upload")
    return Image.fromarray(image_source)
def process(image_source, image_url=None):
    """Remove the background of an image and return an RGBA cut-out.

    ``gr.Interface`` calls its ``fn`` with one positional argument per input
    component, and the demo UI in this file declares two inputs (an image
    upload and a URL textbox), so both are accepted here. ``image_url`` is
    optional so single-argument callers keep working.

    Args:
        image_source: The uploaded image as a numpy array, a URL string, or
            None when nothing was uploaded.
        image_url: Optional URL from the textbox. Used only when
            ``image_source`` is None; blank strings are ignored.

    Returns:
        An RGBA ``PIL.Image`` the same size as the input, in which the original
        pixels are pasted through the predicted mask onto a transparent canvas;
        or None if no input was given or processing failed.
    """
    # Prefer the uploaded image; fall back to the URL textbox when it is empty.
    if image_source is None and isinstance(image_url, str) and image_url.strip():
        image_source = image_url.strip()
    if image_source is None:
        print("No image or URL provided")
        return None
    try:
        # Load and prepare input
        orig_image = load_image(image_source)
        w, h = orig_image.size
        image = resize_image(orig_image)
        im_np = np.array(image)
        # HWC uint8 -> 1xCxHxW float32 scaled to [0, 1]
        im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
        im_tensor = torch.unsqueeze(im_tensor, 0)
        im_tensor = torch.divide(im_tensor, 255.0)
        # Per-channel (x - 0.5) / 1.0, i.e. values shifted to [-0.5, 0.5].
        im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
        # Same device the model was moved to at import time.
        im_tensor = im_tensor.to(device)

        # Inference only: don't build an autograd graph (saves memory and time).
        with torch.no_grad():
            result = net(im_tensor)
            # NOTE(review): assumes result[0][0] is the primary mask with shape
            # (1, 1, H, W) — confirm against BriaRMBG.forward.
            mask = F.interpolate(result[0][0], size=(h, w), mode='bilinear')

        # Min-max stretch to [0, 1]. A perfectly flat prediction would divide by
        # zero and yield NaNs, so in that case only clamp the raw values.
        ma = torch.max(mask)
        mi = torch.min(mask)
        if ma > mi:
            mask = (mask - mi) / (ma - mi)
        else:
            mask = torch.clamp(mask, 0.0, 1.0)

        # Image to PIL. reshape (rather than squeeze) keeps 1-pixel-wide or
        # 1-pixel-tall images two-dimensional.
        im_array = (mask * 255).detach().cpu().numpy().astype(np.uint8).reshape(h, w)
        pil_im = Image.fromarray(im_array)
        # Paste the original image through the mask onto a transparent canvas.
        new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
        new_im.paste(orig_image, mask=pil_im)
        return new_im
    except Exception as e:
        # Top-level UI handler: report the failure (with traceback for
        # debugging) and show an empty output instead of crashing the app.
        import traceback
        print(f"Error during processing: {e}")
        traceback.print_exc()
        return None
# Text shown at the top of the Gradio page.
title = "Background Removal"
description = r"""Background removal model developed by <a href='https://BRIA.AI' target='_blank'><b>BRIA.AI</b></a>, trained on a carefully selected dataset and is available as an open-source model for non-commercial use.<br>
For test upload your image and wait. Read more at model card <a href='https://huggingface.co/briaai/RMBG-1.4' target='_blank'><b>briaai/RMBG-1.4</b></a>.<br>
"""
# A single example row; it only supplies a value for the first input (the image).
examples = [
    ['./input.jpg'],
]

# Gradio passes one positional argument per input component to ``fn``:
# the uploaded image (as a numpy array) and the URL textbox string.
demo = gr.Interface(
    process,
    [
        gr.Image(label="Upload Image", type="numpy"),
        gr.Textbox(label="Image URL"),
    ],
    "image",
    title=title,
    description=description,
    examples=examples,
)

if __name__ == "__main__":
    demo.launch(share=False)