ghostsInTheMachine committed on
Commit
0972107
·
verified ·
1 Parent(s): 73527b1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -162
app.py CHANGED
@@ -3,27 +3,16 @@ import cv2
3
  import numpy as np
4
  import torch
5
  import gradio as gr
6
- import spaces
7
-
8
- from glob import glob
9
- from typing import Tuple
10
 
11
  from PIL import Image
12
- from gradio_imageslider import ImageSlider
13
  from transformers import AutoModelForImageSegmentation
14
  from torchvision import transforms
15
 
16
- import requests
17
- from io import BytesIO
18
- import zipfile
19
-
20
-
21
  torch.set_float32_matmul_precision('high')
22
  torch.jit.script = lambda f: f
23
 
24
  device = "cuda" if torch.cuda.is_available() else "cpu"
25
 
26
- ### image_proc.py
27
  def refine_foreground(image, mask, r=90):
28
  if mask.size != image.size:
29
  mask = mask.resize(image.size)
@@ -33,15 +22,12 @@ def refine_foreground(image, mask, r=90):
33
  image_masked = Image.fromarray((estimated_foreground * 255.0).astype(np.uint8))
34
  return image_masked
35
 
36
-
37
  def FB_blur_fusion_foreground_estimator_2(image, alpha, r=90):
38
- # Thanks to the source: https://github.com/Photoroom/fast-foreground-estimation
39
  alpha = alpha[:, :, None]
40
  F, blur_B = FB_blur_fusion_foreground_estimator(
41
  image, image, image, alpha, r)
42
  return FB_blur_fusion_foreground_estimator(image, F, blur_B, alpha, r=6)[0]
43
 
44
-
45
  def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
46
  if isinstance(image, Image.Image):
47
  image = np.array(image) / 255.0
@@ -57,9 +43,8 @@ def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
57
  F = np.clip(F, 0, 1)
58
  return F, blurred_B
59
 
60
-
61
  class ImagePreprocessor():
62
- def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
63
  self.transform_image = transforms.Compose([
64
  transforms.Resize(resolution),
65
  transforms.ToTensor(),
@@ -70,159 +55,47 @@ class ImagePreprocessor():
70
  image = self.transform_image(image)
71
  return image
72
 
73
-
74
- usage_to_weights_file = {
75
- 'General': 'BiRefNet',
76
- 'General-Lite': 'BiRefNet_lite',
77
- 'General-Lite-2K': 'BiRefNet_lite-2K',
78
- 'Matting': 'BiRefNet-matting',
79
- 'Portrait': 'BiRefNet-portrait',
80
- 'DIS': 'BiRefNet-DIS5K',
81
- 'HRSOD': 'BiRefNet-HRSOD',
82
- 'COD': 'BiRefNet-COD',
83
- 'DIS-TR_TEs': 'BiRefNet-DIS5K-TR_TEs',
84
- 'General-legacy': 'BiRefNet-legacy'
85
- }
86
-
87
- birefnet = AutoModelForImageSegmentation.from_pretrained('/'.join(('zhengpeng7', usage_to_weights_file['General'])), trust_remote_code=True)
88
  birefnet.to(device)
89
  birefnet.eval()
90
 
 
 
 
91
 
92
- @spaces.GPU
93
- def predict(images, resolution, weights_file):
94
- assert (images is not None), 'AssertionError: images cannot be None.'
95
-
96
- global birefnet
97
- # Load BiRefNet with chosen weights
98
- _weights_file = '/'.join(('zhengpeng7', usage_to_weights_file[weights_file] if weights_file is not None else usage_to_weights_file['General']))
99
- print('Using weights: {}.'.format(_weights_file))
100
- birefnet = AutoModelForImageSegmentation.from_pretrained(_weights_file, trust_remote_code=True)
101
- birefnet.to(device)
102
- birefnet.eval()
103
-
104
- try:
105
- resolution = [int(int(reso)//32*32) for reso in resolution.strip().split('x')]
106
- except:
107
- resolution = (1024, 1024) if weights_file not in ['General-Lite-2K'] else (2560, 1440)
108
- print('Invalid resolution input. Automatically changed to 1024x1024 or 2K.')
109
-
110
- if isinstance(images, list):
111
- # For tab_batch
112
- save_paths = []
113
- save_dir = 'preds-BiRefNet'
114
- if not os.path.exists(save_dir):
115
- os.makedirs(save_dir)
116
- tab_is_batch = True
117
- else:
118
- images = [images]
119
- tab_is_batch = False
120
-
121
- for idx_image, image_src in enumerate(images):
122
- if isinstance(image_src, str):
123
- if os.path.isfile(image_src):
124
- image_ori = Image.open(image_src)
125
- else:
126
- response = requests.get(image_src)
127
- image_data = BytesIO(response.content)
128
- image_ori = Image.open(image_data)
129
- else:
130
- image_ori = Image.fromarray(image_src)
131
-
132
- image = image_ori.convert('RGB')
133
- # Preprocess the image
134
- image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
135
- image_proc = image_preprocessor.proc(image)
136
- image_proc = image_proc.unsqueeze(0)
137
-
138
- # Prediction
139
- with torch.no_grad():
140
- preds = birefnet(image_proc.to(device))[-1].sigmoid().cpu()
141
- pred = preds[0].squeeze()
142
-
143
- # Show Results
144
- pred_pil = transforms.ToPILImage()(pred)
145
- image_masked = refine_foreground(image, pred_pil)
146
- image_masked.putalpha(pred_pil.resize(image.size))
147
-
148
- torch.cuda.empty_cache()
149
-
150
- if tab_is_batch:
151
- save_file_path = os.path.join(save_dir, "{}.png".format(os.path.splitext(os.path.basename(image_src))[0]))
152
- image_masked.save(save_file_path)
153
- save_paths.append(save_file_path)
154
-
155
- if tab_is_batch:
156
- zip_file_path = os.path.join(save_dir, "{}.zip".format(save_dir))
157
- with zipfile.ZipFile(zip_file_path, 'w') as zipf:
158
- for file in save_paths:
159
- zipf.write(file, os.path.basename(file))
160
- return save_paths, zip_file_path
161
- else:
162
- return (image_masked, image_ori)
163
-
164
-
165
- examples = [[_] for _ in glob('examples/*')][:]
166
- # Add the option of resolution in a text box.
167
- for idx_example, example in enumerate(examples):
168
- examples[idx_example].append('1024x1024')
169
- examples.append(examples[-1].copy())
170
- examples[-1][1] = '512x512'
171
-
172
- examples_url = [
173
- ['https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg'],
174
- ]
175
- for idx_example_url, example_url in enumerate(examples_url):
176
- examples_url[idx_example_url].append('1024x1024')
177
-
178
- descriptions = ('Upload a picture, our model will extract a highly accurate segmentation of the subject in it.\n)'
179
- ' The resolution used in our training was `1024x1024`, thus the suggested resolution to obtain good results!\n'
180
- ' Our codes can be found at https://github.com/ZhengPeng7/BiRefNet.\n'
181
- ' We also maintain the HF model of BiRefNet at https://huggingface.co/ZhengPeng7/BiRefNet for easier access.')
182
-
183
- tab_image = gr.Interface(
184
- fn=predict,
185
- inputs=[
186
- gr.Image(label='Upload an image'),
187
- gr.Textbox(lines=1, placeholder="Type the resolution (`WxH`) you want, e.g., `1024x1024`.", label="Resolution"),
188
- gr.Radio(list(usage_to_weights_file.keys()), value='General', label="Weights", info="Choose the weights you want.")
189
- ],
190
- outputs=ImageSlider(label="BiRefNet's prediction", type="pil"),
191
- examples=examples,
192
- api_name="image",
193
- description=descriptions,
194
- )
195
 
196
- tab_text = gr.Interface(
197
- fn=predict,
198
- inputs=[
199
- gr.Textbox(label="Paste an image URL"),
200
- gr.Textbox(lines=1, placeholder="Type the resolution (`WxH`) you want, e.g., `1024x1024`.", label="Resolution"),
201
- gr.Radio(list(usage_to_weights_file.keys()), value='General', label="Weights", info="Choose the weights you want.")
202
- ],
203
- outputs=ImageSlider(label="BiRefNet's prediction", type="pil"),
204
- examples=examples_url,
205
- api_name="text",
206
- description=descriptions+'\nTab-URL is partially modified from https://huggingface.co/spaces/not-lain/background-removal, thanks to this great work!',
207
- )
208
 
209
- tab_batch = gr.Interface(
210
- fn=predict,
211
- inputs=[
212
- gr.File(label="Upload multiple images", type="filepath", file_count="multiple"),
213
- gr.Textbox(lines=1, placeholder="Type the resolution (`WxH`) you want, e.g., `1024x1024`.", label="Resolution"),
214
- gr.Radio(list(usage_to_weights_file.keys()), value='General', label="Weights", info="Choose the weights you want.")
215
- ],
216
- outputs=[gr.Gallery(label="BiRefNet's predictions"), gr.File(label="Download masked images.")],
217
- api_name="batch",
218
- description=descriptions+'\nTab-batch is partially modified from https://huggingface.co/spaces/NegiTurkey/Multi_Birefnetfor_Background_Removal, thanks to this great work!',
219
- )
220
 
221
- demo = gr.TabbedInterface(
222
- [tab_image, tab_text, tab_batch],
223
- ['image', 'text', 'batch'],
224
- title="BiRefNet demo for subject extraction (general / matting / salient / camouflaged / portrait).",
 
 
 
 
 
 
 
 
 
 
225
  )
226
 
227
  if __name__ == "__main__":
228
- demo.launch(debug=True)
 
3
  import numpy as np
4
  import torch
5
  import gradio as gr
 
 
 
 
6
 
7
  from PIL import Image
 
8
  from transformers import AutoModelForImageSegmentation
9
  from torchvision import transforms
10
 
 
 
 
 
 
11
  torch.set_float32_matmul_precision('high')
12
  torch.jit.script = lambda f: f
13
 
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
 
 
16
  def refine_foreground(image, mask, r=90):
17
  if mask.size != image.size:
18
  mask = mask.resize(image.size)
 
22
  image_masked = Image.fromarray((estimated_foreground * 255.0).astype(np.uint8))
23
  return image_masked
24
 
 
25
  def FB_blur_fusion_foreground_estimator_2(image, alpha, r=90):
 
26
  alpha = alpha[:, :, None]
27
  F, blur_B = FB_blur_fusion_foreground_estimator(
28
  image, image, image, alpha, r)
29
  return FB_blur_fusion_foreground_estimator(image, F, blur_B, alpha, r=6)[0]
30
 
 
31
  def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
32
  if isinstance(image, Image.Image):
33
  image = np.array(image) / 255.0
 
43
  F = np.clip(F, 0, 1)
44
  return F, blurred_B
45
 
 
46
  class ImagePreprocessor():
47
+ def __init__(self, resolution=(1024, 1024)) -> None:
48
  self.transform_image = transforms.Compose([
49
  transforms.Resize(resolution),
50
  transforms.ToTensor(),
 
55
  image = self.transform_image(image)
56
  return image
57
 
58
# Load the BiRefNet matting checkpoint once at import time so every request
# reuses the same weights. `trust_remote_code=True` lets transformers run the
# modelling code shipped in the model repository.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    'zhengpeng7/BiRefNet-matting',
    trust_remote_code=True,
)
# Move the weights to the selected device and switch to inference mode.
birefnet.to(device)
birefnet.eval()
61
 
62
def predict(image):
    """Remove the background from one uploaded image with BiRefNet.

    Args:
        image: Array delivered by the ``gr.Image(type="numpy")`` input, or
            ``None`` when the user submitted without uploading anything.

    Returns:
        Path of a PNG file containing the refined foreground with the
        predicted matte stored as its alpha channel.

    Raises:
        gr.Error: If no image was supplied.
    """
    # Local import keeps this block self-contained without touching the
    # file-level import list.
    import tempfile

    if image is None:
        raise gr.Error("Please upload an image.")

    image_ori = Image.fromarray(image)
    image = image_ori.convert('RGB')

    # Preprocess the image and add the batch dimension the model is fed with.
    image_preprocessor = ImagePreprocessor(resolution=(1024, 1024))
    image_proc = image_preprocessor.proc(image).unsqueeze(0)

    # Prediction: take the last model output and squash it with a sigmoid.
    with torch.no_grad():
        preds = birefnet(image_proc.to(device))[-1].sigmoid().cpu()
    pred = preds[0].squeeze()

    # Build the result: refined foreground colours plus the matte, resized
    # back to the input size, as the alpha channel.
    pred_pil = transforms.ToPILImage()(pred)
    image_masked = refine_foreground(image, pred_pil)
    image_masked.putalpha(pred_pil.resize(image.size))

    torch.cuda.empty_cache()

    # Save as PNG. Each call gets its own file: a single fixed "output.png"
    # would be overwritten when two requests run at the same time, handing
    # one user the other's result. The file is created with delete=False and
    # closed before saving so it outlives this function; it is not removed
    # here because the caller still needs to read it.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        output_path = tmp.name
    image_masked.save(output_path)

    return output_path
91
+
92
# Single-image demo UI: the upload reaches `predict` as a NumPy array and the
# returned file path is rendered as the output image.
iface = gr.Interface(
    predict,
    gr.Image(type="numpy"),
    gr.Image(type="filepath"),
    title="BiRefNet Matting",
    description="Upload an image to perform matting using BiRefNet.",
)

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch(debug=True)