ZhengPeng7
committed on
Commit
·
160b8f2
1
Parent(s):
e88de74
Add the deployment option.
Browse files- README.md +1 -1
- handler.py +129 -0
- requirements.txt +18 -0
README.md
CHANGED
@@ -7,7 +7,7 @@ tags:
|
|
7 |
- pytorch_model_hub_mixin
|
8 |
- model_hub_mixin
|
9 |
repo_url: https://github.com/ZhengPeng7/BiRefNet
|
10 |
-
pipeline_tag: image-
|
11 |
---
|
12 |
<h1 align="center">Bilateral Reference for High-Resolution Dichotomous Image Segmentation</h1>
|
13 |
|
|
|
7 |
- pytorch_model_hub_mixin
|
8 |
- model_hub_mixin
|
9 |
repo_url: https://github.com/ZhengPeng7/BiRefNet
|
10 |
+
pipeline_tag: image-to-image
|
11 |
---
|
12 |
<h1 align="center">Bilateral Reference for High-Resolution Dichotomous Image Segmentation</h1>
|
13 |
|
handler.py
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This HF deployment code is adapted from https://huggingface.co/not-lain/BiRefNet/raw/main/handler.py.
|
2 |
+
import base64
import os
from io import BytesIO
from typing import Any, Dict, List, Tuple

import cv2
import numpy as np
import requests
import torch
from loadimg import load_img
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
|
9 |
+
|
10 |
+
# Float32 matmul precision: "high" is the active setting ("highest" is the
# other option the original toggle list offered).
torch.set_float32_matmul_precision("high")

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
|
13 |
+
|
14 |
+
### image_proc.py
|
15 |
+
def refine_foreground(image, mask, r=90):
    """Estimate the foreground colours of `image` under `mask`.

    Args:
        image: PIL image to refine.
        mask: PIL mask; resized to `image.size` when the two sizes differ.
        r: blur size forwarded to `FB_blur_fusion_foreground_estimator_2`.

    Returns:
        A PIL image built from the estimated foreground, rescaled to uint8.
    """
    if mask.size != image.size:
        mask = mask.resize(image.size)

    # Both inputs are divided by 255 before estimation
    # (assumes 8-bit pixel data -- TODO confirm).
    image_arr = np.array(image) / 255.0
    alpha_arr = np.array(mask) / 255.0

    foreground = FB_blur_fusion_foreground_estimator_2(image_arr, alpha_arr, r=r)
    foreground_u8 = (foreground * 255.0).astype(np.uint8)
    return Image.fromarray(foreground_u8)
|
23 |
+
|
24 |
+
|
25 |
+
def FB_blur_fusion_foreground_estimator_2(image, alpha, r=90):
    """Run the blur-fusion foreground estimator twice and return the foreground.

    The first pass uses `image` as the initial guess for both foreground and
    background with blur size `r`; the second pass refines that result with a
    fixed blur size of 6.

    Args:
        image: image array passed through to the single-pass estimator.
        alpha: alpha array; a trailing axis is appended before use.
        r: blur size of the first pass.

    Returns:
        The foreground estimate produced by the second pass.
    """
    # Thanks to the source: https://github.com/Photoroom/fast-foreground-estimation
    alpha_3d = alpha[:, :, None]
    coarse_F, coarse_blur_B = FB_blur_fusion_foreground_estimator(
        image, image, image, alpha_3d, r
    )
    refined_F, _ = FB_blur_fusion_foreground_estimator(
        image, coarse_F, coarse_blur_B, alpha_3d, r=6
    )
    return refined_F
|
30 |
+
|
31 |
+
|
32 |
+
def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
    """Perform one blur-fusion pass of foreground estimation.

    Args:
        image: observed image; a PIL image is converted to an array and
            divided by 255.
        F: current foreground estimate.
        B: current background estimate.
        alpha: alpha array (the two-pass wrapper in this file passes it with
            a trailing singleton axis).
        r: side length of the square `cv2.blur` kernel.

    Returns:
        A tuple `(F, blurred_B)`: the updated foreground clipped to [0, 1]
        and the blurred background estimate.
    """
    if isinstance(image, Image.Image):
        image = np.array(image) / 255.0

    kernel = (r, r)
    eps = 1e-5  # keeps the divisions below away from zero

    # The blur result is re-expanded with a trailing axis; presumably cv2.blur
    # drops the singleton channel axis of `alpha` -- TODO confirm.
    blurred_alpha = cv2.blur(alpha, kernel)[:, :, None]

    # Alpha-weighted local mean of the foreground.
    blurred_F = cv2.blur(F * alpha, kernel) / (blurred_alpha + eps)

    # (1 - alpha)-weighted local mean of the background.
    blurred_B = cv2.blur(B * (1 - alpha), kernel) / ((1 - blurred_alpha) + eps)

    # Correct the blurred foreground by the alpha-scaled compositing residual.
    residual = image - alpha * blurred_F - (1 - alpha) * blurred_B
    F = np.clip(blurred_F + alpha * residual, 0, 1)
    return F, blurred_B
|
46 |
+
|
47 |
+
|
48 |
+
class ImagePreprocessor:
    """Resize, tensorize and normalize an image before it is fed to the model."""

    def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
        """Build the transform pipeline.

        Args:
            resolution: size handed to `transforms.Resize`.
        """
        resize = transforms.Resize(resolution)
        to_tensor = transforms.ToTensor()
        # Per-channel mean / std (presumably the ImageNet statistics -- confirm).
        normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        self.transform_image = transforms.Compose([resize, to_tensor, normalize])

    def proc(self, image: Image.Image) -> torch.Tensor:
        """Apply the transform pipeline to `image` and return the result."""
        return self.transform_image(image)
|
59 |
+
|
60 |
+
# Maps a usage name to the name of the weights repository that is loaded from
# the `zhengpeng7` namespace below.
usage_to_weights_file = {
    'General': 'BiRefNet',
    'General-Lite': 'BiRefNet_lite',
    'General-Lite-2K': 'BiRefNet_lite-2K',
    'General-reso_512': 'BiRefNet-reso_512',
    'Matting': 'BiRefNet-matting',
    'Portrait': 'BiRefNet-portrait',
    'DIS': 'BiRefNet-DIS5K',
    'HRSOD': 'BiRefNet-HRSOD',
    'COD': 'BiRefNet-COD',
    'DIS-TR_TEs': 'BiRefNet-DIS5K-TR_TEs',
    'General-legacy': 'BiRefNet-legacy'
}

# Usage key of the deployed weights. It was previously referenced by the
# resolution selection below without ever being defined, which raised a
# NameError as soon as this module was imported.
weights_file = 'General'

birefnet = AutoModelForImageSegmentation.from_pretrained(
    '/'.join(('zhengpeng7', usage_to_weights_file[weights_file])),
    trust_remote_code=True,
)
birefnet.to(device)
birefnet.eval()

# Set resolution
# NOTE(review): this value is passed to torchvision's `transforms.Resize`,
# which reads a 2-sequence as (height, width) -- confirm that (2560, 1440)
# is in the intended order for the 2K weights.
if weights_file in ['General-Lite-2K']:
    resolution = (2560, 1440)
elif weights_file in ['General-reso_512']:
    resolution = (512, 512)
else:
    resolution = (1024, 1024)
|
85 |
+
|
86 |
+
|
87 |
+
class EndpointHandler():
    """Endpoint handler that returns the input image with the predicted mask as alpha."""

    def __init__(self, path=""):
        """Load the BiRefNet weights and move them to the selected device.

        Args:
            path: accepted but not used; the weights are always loaded from
                "ZhengPeng7/BiRefNet".
        """
        self.birefnet = AutoModelForImageSegmentation.from_pretrained(
            "ZhengPeng7/BiRefNet", trust_remote_code=True
        )
        self.birefnet.to(device)
        # Inference only: match the eval-mode setup of the module-level model.
        self.birefnet.eval()

    def __call__(self, data: Dict[str, Any]):
        """Segment the image referenced by `data["inputs"]`.

        Args:
            data: request payload. `data["inputs"]` may be
                - a `str`: a local file path if such a file exists, otherwise
                  it is treated as a URL and downloaded;
                - a PIL image, used as-is;
                - anything else accepted by `Image.fromarray`.

        Returns:
            The refined foreground as a PIL image whose alpha channel is the
            predicted mask resized to the input image size.

        Raises:
            requests.HTTPError: if downloading the image returns an error status.
        """
        print('data["inputs"] = ', data["inputs"])
        image_src = data["inputs"]
        if isinstance(image_src, Image.Image):
            # Already decoded; no conversion needed.
            image_ori = image_src
        elif isinstance(image_src, str):
            if os.path.isfile(image_src):
                image_ori = Image.open(image_src)
            else:
                # Bound the wait and fail loudly on HTTP errors instead of
                # handing an error page to the image decoder.
                response = requests.get(image_src, timeout=30)
                response.raise_for_status()
                image_data = BytesIO(response.content)
                image_ori = Image.open(image_data)
        else:
            image_ori = Image.fromarray(image_src)

        image = image_ori.convert('RGB')
        # Preprocess the image
        image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
        image_proc = image_preprocessor.proc(image)
        image_proc = image_proc.unsqueeze(0)

        # Prediction: use the model owned by this handler rather than the
        # module-level global, so the weights loaded in __init__ are the ones used.
        with torch.no_grad():
            preds = self.birefnet(image_proc.to(device))[-1].sigmoid().cpu()
        pred = preds[0].squeeze()

        # Show Results
        pred_pil = transforms.ToPILImage()(pred)
        image_masked = refine_foreground(image, pred_pil)
        image_masked.putalpha(pred_pil.resize(image.size))
        return image_masked
|
requirements.txt
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
--extra-index-url https://download.pytorch.org/whl/cu118
|
2 |
+
torch==2.0.1
|
3 |
+
--extra-index-url https://download.pytorch.org/whl/cu118
|
4 |
+
torchvision==0.15.2
|
5 |
+
numpy<2
|
6 |
+
opencv-python
|
7 |
+
timm
|
8 |
+
scipy
|
9 |
+
scikit-image
|
10 |
+
kornia
|
11 |
+
einops
|
12 |
+
|
13 |
+
tqdm
|
14 |
+
prettytable
|
15 |
+
|
16 |
+
transformers
|
17 |
+
huggingface-hub>0.25
|
18 |
+
accelerate
|