File size: 2,408 Bytes
4aaf04f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import gradio as gr
from transformers import pipeline, AutoModelForImageSegmentation 
from gradio_imageslider import ImageSlider
import torch
from torchvision import transforms
import spaces
from PIL import Image

import numpy as np
import time

# Load the BiRefNet segmentation checkpoint from the Hugging Face Hub.
# trust_remote_code lets transformers execute the model code hosted in that
# Hub repository.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",
    trust_remote_code=True,
)

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# nn.Module.to moves parameters in place, so the return value is not needed.
birefnet.to(device)

# Preprocessing applied to every input before inference: resize to 1024x1024,
# convert to a float tensor in [0, 1], then normalize each channel with the
# standard ImageNet mean/std (three values each, so inputs must be 3-channel).
transform_image = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# @spaces.GPU
# def PreProcess(image):
#     size = image.size
#     image = transform_image(image).unsqueeze(0).to(device)

#     with torch.no_grad():
#         preds = birefnet(image)[-1].sigmoid().cpu()
#     pred = preds[0].squeeze()
#     pred = transforms.ToPILImage()(pred)
#     mask = pred.resize(size)
#     # image.putalpha(mask)
#     return image

@spaces.GPU
def PreProcess(image):
    """Remove the background of *image* using BiRefNet.

    Runs the segmentation model on a resized, normalized copy of the image,
    scales the predicted foreground mask back to the original resolution and
    installs it as the alpha channel of an RGBA copy of the image.

    Args:
        image: A PIL image. Any mode is accepted; the model input is
            converted to RGB first because ``transform_image`` normalizes
            exactly three channels.

    Returns:
        A new RGBA PIL image with the same size as *image*, whose alpha
        channel is the predicted mask. The caller's image is not modified.
    """
    size = image.size  # (width, height) of the original; the mask is resized back to it

    # Normalize uses 3-element mean/std, so RGBA / L / P inputs would fail
    # without converting to RGB. Skip the extra copy when already RGB.
    rgb = image if image.mode == "RGB" else image.convert("RGB")
    image_tensor = transform_image(rgb).unsqueeze(0).to(device)  # add batch dimension

    with torch.no_grad():
        # The last element of the model output is used as the prediction;
        # sigmoid squashes the raw output into [0, 1].
        preds = birefnet(image_tensor)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()

    # ToPILImage turns a float tensor in [0, 1] into an 8-bit grayscale image,
    # which is then stretched back to the original resolution.
    mask = transforms.ToPILImage()(pred).resize(size)

    # convert() returns a new image, so the input stays untouched.
    result = image.convert("RGBA")
    result.putalpha(mask)
    return result

def segment_image(image):
    """Gradio handler: remove the background of an uploaded image.

    Args:
        image: The uploaded image as a NumPy array (assumed to be the default
            ``gr.Image`` output type), or ``None`` when nothing was uploaded.

    Returns:
        A pair ``((cutout, original), seconds)``: ``cutout`` is the RGBA
        result of ``PreProcess``, ``original`` is the RGB input (the two
        sides of the before/after slider), and ``seconds`` is the elapsed
        processing time rounded to two decimals.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Image.fromarray(None) would fail with an opaque AttributeError.
        raise gr.Error("Please upload an image first.")

    # perf_counter is monotonic and high-resolution, so the measured
    # duration cannot be skewed by wall-clock adjustments (unlike time.time).
    start = time.perf_counter()
    rgb = Image.fromarray(image).convert("RGB")
    original = rgb.copy()  # kept separately for the "before" side of the slider
    cutout = PreProcess(rgb)
    elapsed = np.round(time.perf_counter() - start, 2)
    return (cutout, original), elapsed

# Output component: before/after slider receiving (cutout, original) PIL images.
slider = ImageSlider(label='birefnet', type="pil")
# Input component: the uploaded image passed to segment_image.
image = gr.Image(label="Upload an Image")

# Image.open is lazy and keeps the file handle open until the pixels are
# read; load a detached copy inside a context manager so the file is closed.
with Image.open("butterfly.png") as example_file:
    butterfly = example_file.copy()

# Output component showing the processing time returned by segment_image.
time_taken = gr.Textbox(label="Time taken", type="text")

demo = gr.Interface(
    segment_image,
    inputs=image,
    outputs=[slider, time_taken],
    examples=[butterfly],
    api_name="BiRefNet",
)

if __name__ == "__main__":
    demo.launch()