# Page header captured together with the file (status, size, revision
# hashes, line-number gutter); not program code.
# Spaces:
# Running
# Running
# File size: 2,073 Bytes
# 542c815 3f8e328 542c815 c862667 542c815 42cca0b d6e753e 42cca0b 4189f11 42cca0b 018621a 3267028 b98efed 42cca0b 0ce70fb 4189f11 542c815 988f91c 542c815 4189f11 42cca0b 4189f11 8a70686 4189f11 8a70686 4189f11 42cca0b 4189f11 8a70686 4189f11 42cca0b 4189f11 8a70686 42cca0b 4189f11 42cca0b 4189f11 8a70686 032f16e 4189f11 58f9d23 4189f11 d909bca 4189f11
# 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67
import numpy as np
import torch
import torch.nn.functional as F
import functools
from torchvision.transforms.functional import normalize
import gradio as gr
from gradio_imageslider import ImageSlider
from briarmbg import BriaRMBG
import PIL
from PIL import Image
from typing import Tuple
import requests
from io import BytesIO
# Load the pretrained background-removal network once, at import time, so
# every request reuses the same weights.
net = BriaRMBG.from_pretrained("briaai/RMBG-1.4")
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net.to(device)
# This app only runs inference. Assuming BriaRMBG is a torch.nn.Module (it is
# moved with .to(device) and called on tensors), eval() makes layers such as
# dropout / batch-norm use their inference behaviour regardless of how the
# loader left the model.
net.eval()
@functools.lru_cache()
def _fetch_url_bytes(url):
    """Download *url* and return the raw response body as ``bytes``.

    Results are memoised per URL. Only successful downloads are stored:
    ``lru_cache`` does not cache calls that raise, so a timeout or an HTTP
    error status is retried on the next request instead of being remembered.
    """
    user_agent = {'User-agent': 'gradio-app'}
    # Without a timeout a stalled server would block the worker forever.
    response = requests.get(url, headers=user_agent, timeout=30)
    # Fail loudly on 4xx/5xx rather than handing an error page to the decoder.
    response.raise_for_status()
    return response.content


def get_url_im(url):
    """Return the content behind *url* as a fresh, rewound ``BytesIO``.

    The downloaded bytes are cached per URL, but every call gets its own
    stream object, so callers never share a read position (the immutable
    ``bytes`` payload is what is cached, not the mutable stream).

    Args:
        url: Address of the resource to download.

    Returns:
        A new ``io.BytesIO`` positioned at offset 0.

    Raises:
        requests.RequestException: On connection problems, timeouts or a
            non-success HTTP status.
    """
    return BytesIO(_fetch_url_bytes(url))


# Keep the cache-management helpers reachable under the public name, as they
# were when the decorator sat directly on get_url_im.
get_url_im.cache_info = _fetch_url_bytes.cache_info
get_url_im.cache_clear = _fetch_url_bytes.cache_clear
def resize_image(image_url):
    """Fetch the picture at *image_url* and return it as a 1024x1024 RGB image.

    The data is obtained through ``get_url_im``, decoded with PIL, converted
    to RGB and then stretched to the fixed target size with bilinear
    resampling (the aspect ratio is not preserved).
    """
    # Fixed square size this file feeds to the model.
    target_size = (1024, 1024)
    rgb_picture = Image.open(get_url_im(image_url)).convert('RGB')
    return rgb_picture.resize(target_size, Image.BILINEAR)
def process(image_url):
    """Remove the background from the image found at *image_url*.

    The image is downloaded and resized by ``resize_image``, run through the
    module-level ``net``, and the predicted mask is used as the alpha channel
    of the returned picture.

    Args:
        image_url: URL of the image to process.

    Returns:
        An RGBA ``PIL.Image.Image`` with the same size as the image returned
        by ``resize_image``; pixels the mask rejects are transparent.
    """
    # prepare input
    # NOTE(review): despite the name, this is the already-resized 1024x1024
    # picture, so the output is produced at that size rather than at the
    # source resolution - confirm that this is intended.
    orig_image = resize_image(image_url)
    w, h = orig_image.size
    im_np = np.array(orig_image)
    # H x W x C pixel array -> 1 x C x H x W float tensor.
    im_tensor = torch.tensor(im_np, dtype=torch.float32).permute(2, 0, 1)
    im_tensor = torch.unsqueeze(im_tensor, 0)
    # Scale to [0, 1], then subtract 0.5 per channel (std of 1.0 is a no-op).
    im_tensor = torch.divide(im_tensor, 255.0)
    im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    # Follow the model onto whatever device it was placed on at start-up.
    im_tensor = im_tensor.to(device)

    # inference - no gradients are needed, so skip building the autograd graph
    with torch.no_grad():
        result = net(im_tensor)
        # post process
        # result[0][0] is presumably the finest-scale mask as a 4-D tensor
        # (bilinear interpolate requires one) - TODO confirm against BriaRMBG.
        result = torch.squeeze(
            F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0
        )

    # Min-max stretch the mask to [0, 1].
    ma = torch.max(result)
    mi = torch.min(result)
    value_range = ma - mi
    if value_range > 0:
        result = (result - mi) / value_range
    else:
        # A constant mask would divide by zero and yield NaN; keep the raw
        # values instead (assumes they are meant as [0, 1] alpha - confirm).
        result = torch.clamp(result, 0.0, 1.0)

    # image to pil
    im_array = (result * 255).detach().cpu().numpy().astype(np.uint8)
    pil_im = Image.fromarray(np.squeeze(im_array))

    # paste the mask on the original image
    new_im = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
    new_im.paste(orig_image, mask=pil_im)
    return new_im
# Gradio UI: a single text box that takes the image URL, wired to process(),
# whose PIL result is shown in an image component.
iface = gr.Interface(
    fn=process,
    inputs=gr.Textbox(label="Text or Image URL"),
    outputs=gr.Image(type="pil", label="Output Image"),
)
# Start the web server (runs as soon as the module is executed).
iface.launch()