File size: 3,068 Bytes
542c815
3f8e328
542c815
 
 
d6e753e
 
 
8a357d1
487aa58
 
3267028
018621a
3267028
b98efed
542c815
 
 
988f91c
 
542c815
 
04ba376
 
 
 
 
487aa58
 
04ba376
 
487aa58
04ba376
487aa58
 
542c815
487aa58
04ba376
 
 
 
 
 
 
 
 
 
 
 
70974c3
04ba376
 
 
 
 
 
 
 
 
 
 
 
 
542c815
04ba376
 
 
 
d909bca
c530952
6ca28a8
4941fcb
6ca28a8
d909bca
487aa58
 
 
 
5682b91
 
487aa58
 
 
 
 
 
d909bca
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
import gradio as gr
from briarmbg import BriaRMBG
import PIL
from PIL import Image
from typing import Tuple
import requests
from io import BytesIO

# Choose the compute device first: CUDA when torch reports it, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pretrained RMBG-1.4 checkpoint once at import time and place the
# model on the chosen device; `process` reuses this single instance.
net = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
net.to(device)

def resize_image(image):
    """Return an RGB version of ``image`` scaled to 1024x1024.

    Args:
        image: A PIL image; it is converted to RGB before resizing.

    Returns:
        A new 1024x1024 RGB PIL image produced with bilinear resampling.
        The aspect ratio of the input is not preserved.
    """
    target_size = (1024, 1024)
    rgb_image = image.convert('RGB')
    return rgb_image.resize(target_size, Image.BILINEAR)

def get_url_image(url, timeout=10):
    """Download ``url`` and return its body as an in-memory byte stream.

    Args:
        url: HTTP(S) address of the image to fetch.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` may block indefinitely on a stalled host.

    Returns:
        A ``BytesIO`` positioned at the start of the downloaded bytes.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.RequestException: For connection problems or timeouts.
    """
    # NOTE(review): the URL comes from user input; consider restricting the
    # allowed hosts/schemes if this app is exposed publicly.
    headers = {'User-Agent': 'gradio-app'}
    response = requests.get(url, headers=headers, timeout=timeout)
    # Fail here on error responses instead of handing an HTML error page to
    # the image decoder, which would produce a misleading "cannot identify
    # image file" error further down.
    response.raise_for_status()
    return BytesIO(response.content)

def load_image(image_source):
    """Turn a URL string or an uploaded pixel array into a PIL image.

    Args:
        image_source: Either a string, which is treated as a URL and
            downloaded with ``get_url_image``, or an array-like object that
            ``Image.fromarray`` accepts (the file-upload path).

    Returns:
        The loaded PIL image.

    Raises:
        ValueError: If ``image_source`` is ``None`` or a blank string, i.e.
            no image was actually supplied.
    """
    # Reject "nothing supplied" up front so the caller gets a clear message
    # rather than an obscure failure inside PIL or requests.
    if image_source is None:
        raise ValueError("No image provided: upload a file or enter an image URL")

    if isinstance(image_source, str):  # Strings are treated as URLs
        url = image_source.strip()
        if not url:
            raise ValueError("No image provided: the image URL is empty")
        print(f"Loading image from URL: {url}")
        return Image.open(get_url_image(url))

    print("Loading image from file upload")
    return Image.fromarray(image_source)

def process(image_source, image_url=None):
    """Remove the background from an image and return an RGBA cut-out.

    The interface defined in this file declares two inputs (an image upload
    and a URL text box), so the callback receives two positional arguments.
    ``image_url`` is optional so single-argument callers keep working.

    Args:
        image_source: Uploaded image as a NumPy array, or a URL string
            (anything ``load_image`` accepts). May be ``None`` when only a
            URL was entered.
        image_url: Optional URL string, used only when ``image_source`` is
            ``None`` and the string is not blank.

    Returns:
        An RGBA PIL image the size of the input, with the original pixels
        pasted through the predicted mask, or ``None`` if any step failed
        (the error is printed).
    """
    try:
        # Fall back to the URL field when nothing was uploaded.
        if image_source is None and isinstance(image_url, str) and image_url.strip():
            image_source = image_url.strip()

        # Load and prepare input
        orig_image = load_image(image_source)
        w, h = orig_image.size
        image = resize_image(orig_image)
        im_np = np.array(image)
        # HWC uint8 -> 1xCxHxW float in [0, 1], then shift by the 0.5 mean.
        im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
        im_tensor = torch.unsqueeze(im_tensor, 0)
        im_tensor = torch.divide(im_tensor, 255.0)
        im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
        # Use the same device the model was moved to at module load.
        im_tensor = im_tensor.to(device)

        # Inference and post-processing need no gradients; disabling autograd
        # avoids building a graph and saves memory.
        with torch.no_grad():
            result = net(im_tensor)
            # result[0][0] is the output the mask is taken from; it is scaled
            # back to the original (height, width).
            result = torch.squeeze(
                F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0
            )
            ma = torch.max(result)
            mi = torch.min(result)
            # Min-max normalise to [0, 1]. The clamp prevents a 0/0 division
            # (NaNs) when the prediction is constant; that case yields an
            # all-zero, fully transparent mask.
            result = (result - mi) / torch.clamp(ma - mi, min=1e-8)

        # Image to PIL
        im_array = (result * 255).detach().cpu().numpy().astype(np.uint8)
        pil_im = Image.fromarray(np.squeeze(im_array))
        # Paste the original image through the mask onto a transparent canvas
        new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
        new_im.paste(orig_image, mask=pil_im)

        return new_im
    except Exception as e:
        # Top-level UI boundary: report the problem and return no image.
        print(f"Error during processing: {e}")
        return None

# Heading and HTML blurb passed to the Gradio interface as `title` and
# `description`.
title = "Background Removal"
# Raw triple-quoted string: the embedded HTML tags are kept verbatim.
description = r"""Background removal model developed by <a href='https://BRIA.AI' target='_blank'><b>BRIA.AI</b></a>, trained on a carefully selected dataset and is available as an open-source model for non-commercial use.<br> 
For test upload your image and wait. Read more at model card <a href='https://huggingface.co/briaai/RMBG-1.4' target='_blank'><b>briaai/RMBG-1.4</b></a>.<br>
"""
# A single example row containing one image path.
examples = [['./input.jpg'],]

# Two input widgets: an image upload (handed to the callback as a NumPy
# array) and a text box for an image URL.
upload_input = gr.Image(type="numpy", label="Upload Image")
url_input = gr.Textbox(label="Image URL")

# Wire the widgets, the callback and the page texts into one interface.
demo = gr.Interface(
    fn=process,
    inputs=[upload_input, url_input],
    outputs="image",
    title=title,
    description=description,
    examples=examples,
)

if __name__ == "__main__":
    # Start the app without requesting a public share link.
    demo.launch(share=False)