VenkateshRoshan commited on
Commit
4aaf04f
·
1 Parent(s): 474e221

App updated.

Browse files
Files changed (4) hide show
  1. app.py +84 -0
  2. butterfly.png +0 -0
  3. dockerfile +0 -0
  4. requirements.txt +13 -0
app.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import pipeline, AutoModelForImageSegmentation
3
+ from gradio_imageslider import ImageSlider
4
+ import torch
5
+ from torchvision import transforms
6
+ import spaces
7
+ from PIL import Image
8
+
9
+ import numpy as np
10
+ import time
11
+
12
+ birefnet = AutoModelForImageSegmentation.from_pretrained(
13
+ "ZhengPeng7/BiRefNet", trust_remote_code=True
14
+ )
15
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+
17
+ print("Using device:", device)
18
+
19
+ birefnet.to(device)
20
+ transform_image = transforms.Compose(
21
+ [
22
+ transforms.Resize((1024, 1024)),
23
+ transforms.ToTensor(),
24
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
25
+ ]
26
+ )
27
+
28
+ # @spaces.GPU
29
+ # def PreProcess(image):
30
+ # size = image.size
31
+ # image = transform_image(image).unsqueeze(0).to(device)
32
+
33
+ # with torch.no_grad():
34
+ # preds = birefnet(image)[-1].sigmoid().cpu()
35
+ # pred = preds[0].squeeze()
36
+ # pred = transforms.ToPILImage()(pred)
37
+ # mask = pred.resize(size)
38
+ # # image.putalpha(mask)
39
+ # return image
40
+
41
+ @spaces.GPU
42
+ def PreProcess(image):
43
+ size = image.size # Save original size
44
+ image_tensor = transform_image(image).unsqueeze(0).to(device) # Transform the image into a tensor
45
+
46
+ with torch.no_grad():
47
+ preds = birefnet(image_tensor)[-1].sigmoid().cpu() # Get predictions
48
+ pred = preds[0].squeeze()
49
+
50
+ # Convert the prediction tensor to a PIL image
51
+ pred_pil = transforms.ToPILImage()(pred)
52
+
53
+ # Resize the mask to match the original image size
54
+ mask = pred_pil.resize(size)
55
+
56
+ # Convert the original image (passed as input) to a PIL image
57
+ image_pil = image.convert("RGBA") # Ensure the image has an alpha channel
58
+
59
+ # Apply the alpha mask to the image
60
+ image_pil.putalpha(mask)
61
+
62
+ return image_pil
63
+
64
+ def segment_image(image):
65
+ start = time.time()
66
+ image = Image.fromarray(image)
67
+ image = image.convert("RGB")
68
+ org = image.copy()
69
+ image = PreProcess(image)
70
+ time_taken = np.round((time.time() - start),2)
71
+ return (image, org), time_taken
72
+
73
+ slider = ImageSlider(label='birefnet', type="pil")
74
+ image = gr.Image(label="Upload an Image")
75
+
76
+ butterfly = Image.open("butterfly.png")
77
+
78
+ time_taken = gr.Textbox(label="Time taken", type="text")
79
+
80
+ demo = gr.Interface(
81
+ segment_image, inputs=image, outputs=[slider,time_taken], examples=[butterfly], api_name="BiRefNet")
82
+
83
+ if __name__ == '__main__' :
84
+ demo.launch()
butterfly.png ADDED
dockerfile ADDED
File without changes
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ accelerate
3
+ opencv-python
4
+ spaces
5
+ torchvision
6
+ pillow
7
+ numpy
8
+ huggingface-hub
9
+ gradio
10
+ gradio-imageslider
11
+ transformers
12
+ timm
13
+ kornia