ZhengPeng7 committed on
Commit
160b8f2
·
1 Parent(s): e88de74

Add the deployment option.

Browse files
Files changed (3) hide show
  1. README.md +1 -1
  2. handler.py +129 -0
  3. requirements.txt +18 -0
README.md CHANGED
@@ -7,7 +7,7 @@ tags:
7
  - pytorch_model_hub_mixin
8
  - model_hub_mixin
9
  repo_url: https://github.com/ZhengPeng7/BiRefNet
10
- pipeline_tag: image-segmentation
11
  ---
12
  <h1 align="center">Bilateral Reference for High-Resolution Dichotomous Image Segmentation</h1>
13
 
 
7
  - pytorch_model_hub_mixin
8
  - model_hub_mixin
9
  repo_url: https://github.com/ZhengPeng7/BiRefNet
10
+ pipeline_tag: image-to-image
11
  ---
12
  <h1 align="center">Bilateral Reference for High-Resolution Dichotomous Image Segmentation</h1>
13
 
handler.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This Hugging Face deployment code is adapted from https://huggingface.co/not-lain/BiRefNet/raw/main/handler.py.
2
import base64
import os
from io import BytesIO
from typing import Any, Dict, List, Tuple

import cv2
import numpy as np
import requests
import torch
from loadimg import load_img
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
9
+
10
# Use the "high" float32 matmul precision mode ("highest" is the other
# option the original code kept at hand).
torch.set_float32_matmul_precision("high")

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
13
+
14
+ ### image_proc.py
15
def refine_foreground(image, mask, r=90):
    """Estimate the foreground colours of ``image`` under ``mask``.

    The mask is resized to the image size when the two differ. Both inputs
    are turned into float arrays divided by 255 and handed to
    ``FB_blur_fusion_foreground_estimator_2``; its result is scaled back by
    255, cast to ``uint8`` and wrapped in a PIL image.

    Args:
        image: source picture (must expose ``.size``; presumably a PIL image).
        mask: matte for the picture (must expose ``.size`` and ``.resize``).
        r: blur radius forwarded to the estimator.

    Returns:
        The image built by ``Image.fromarray`` from the estimated foreground.
    """
    matte = mask if mask.size == image.size else mask.resize(image.size)
    image_arr = np.array(image) / 255.0
    matte_arr = np.array(matte) / 255.0
    foreground = FB_blur_fusion_foreground_estimator_2(image_arr, matte_arr, r=r)
    foreground_u8 = (foreground * 255.0).astype(np.uint8)
    return Image.fromarray(foreground_u8)
23
+
24
+
25
def FB_blur_fusion_foreground_estimator_2(image, alpha, r=90):
    """Run the blur-fusion estimator twice and return the final foreground.

    Thanks to the source: https://github.com/Photoroom/fast-foreground-estimation

    The first pass uses ``image`` as the initial foreground and background
    guess with blur size ``r``; the second pass refines that result with a
    fixed blur size of 6.

    Args:
        image: array holding the picture.
        alpha: array holding the matte; a trailing axis is appended here.
        r: blur size of the first pass.

    Returns:
        The foreground estimate produced by the second pass.
    """
    alpha_3d = alpha[:, :, None]
    coarse_F, coarse_B = FB_blur_fusion_foreground_estimator(
        image, image, image, alpha_3d, r
    )
    refined_F, _ = FB_blur_fusion_foreground_estimator(
        image, coarse_F, coarse_B, alpha_3d, r=6
    )
    return refined_F
30
+
31
+
32
def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
    """Perform one blur-fusion step of foreground/background estimation.

    Args:
        image: the picture; a PIL image is converted to a float array
            divided by 255, anything else is used as given.
        F: current foreground estimate.
        B: current background estimate.
        alpha: matte array (the caller in this file passes it with a
            trailing axis already appended).
        r: side length of the square ``cv2.blur`` kernel.

    Returns:
        A tuple ``(F, blurred_B)``: the updated foreground clipped to
        ``[0, 1]`` and the blurred background estimate.
    """
    if isinstance(image, Image.Image):
        image = np.array(image) / 255.0

    kernel = (r, r)
    eps = 1e-5  # keeps the divisions below away from zero
    inv_alpha = 1 - alpha

    # A trailing axis is re-added to the blurred matte, presumably so that it
    # lines up with the colour channels in the divisions below.
    alpha_blur = cv2.blur(alpha, kernel)[:, :, None]

    # Alpha-weighted local averages of the foreground and the background.
    F_blur = cv2.blur(F * alpha, kernel) / (alpha_blur + eps)
    B_blur = cv2.blur(B * inv_alpha, kernel) / ((1 - alpha_blur) + eps)

    # Correct the blurred foreground by the compositing residual.
    residual = image - alpha * F_blur - inv_alpha * B_blur
    F_new = np.clip(F_blur + alpha * residual, 0, 1)
    return F_new, B_blur
46
+
47
+
48
class ImagePreprocessor():
    """Resize, tensorise and normalise an image before it is fed to the model."""

    def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
        """Build the torchvision pipeline.

        Args:
            resolution: target size handed to ``transforms.Resize`` as-is.
        """
        # NOTE(review): torchvision's Resize reads a 2-sequence as (h, w);
        # confirm that callers passing non-square sizes use that order.
        resize = transforms.Resize(resolution)
        to_tensor = transforms.ToTensor()
        # Per-channel mean / std (presumably the usual ImageNet statistics).
        channel_mean = [0.485, 0.456, 0.406]
        channel_std = [0.229, 0.224, 0.225]
        normalize = transforms.Normalize(channel_mean, channel_std)
        self.transform_image = transforms.Compose([resize, to_tensor, normalize])

    def proc(self, image: Image.Image) -> torch.Tensor:
        """Apply the pipeline to ``image`` and return the resulting tensor."""
        return self.transform_image(image)
59
+
60
# Usage label -> model name; the name is joined with the 'zhengpeng7'
# namespace when the checkpoint is loaded below.
usage_to_weights_file = dict([
    ('General', 'BiRefNet'),
    ('General-Lite', 'BiRefNet_lite'),
    ('General-Lite-2K', 'BiRefNet_lite-2K'),
    ('General-reso_512', 'BiRefNet-reso_512'),
    ('Matting', 'BiRefNet-matting'),
    ('Portrait', 'BiRefNet-portrait'),
    ('DIS', 'BiRefNet-DIS5K'),
    ('HRSOD', 'BiRefNet-HRSOD'),
    ('COD', 'BiRefNet-COD'),
    ('DIS-TR_TEs', 'BiRefNet-DIS5K-TR_TEs'),
    ('General-legacy', 'BiRefNet-legacy'),
])
73
+
74
# Usage label (a key of `usage_to_weights_file`) selecting the checkpoint to
# serve. The original code read `weights_file` without ever defining it,
# which raised NameError as soon as the module was imported.
weights_file = 'General'

# Module-level model instance, loaded once at import time and put in
# inference mode on the selected device.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    '/'.join(('zhengpeng7', usage_to_weights_file[weights_file])),
    trust_remote_code=True,
)
birefnet.to(device)
birefnet.eval()

# Set resolution: the input size used for preprocessing depends on the
# selected checkpoint.
if weights_file in ['General-Lite-2K']:
    resolution = (2560, 1440)
elif weights_file in ['General-reso_512']:
    resolution = (512, 512)
else:
    resolution = (1024, 1024)
85
+
86
+
87
class EndpointHandler():
    """Request handler that cuts the foreground out of an input image.

    The handler owns a BiRefNet model and a preprocessing pipeline. Each call
    loads the image named in the request, predicts a mask, refines the
    foreground colours and returns the picture with the mask as its alpha
    channel.
    """

    # Seconds to wait for a remote image before giving up.
    request_timeout = 30

    def __init__(self, path=""):
        """Load the model and build the preprocessing pipeline.

        Args:
            path: accepted for interface compatibility; the weights are
                always fetched from the "ZhengPeng7/BiRefNet" hub id.
        """
        self.birefnet = AutoModelForImageSegmentation.from_pretrained(
            "ZhengPeng7/BiRefNet", trust_remote_code=True
        )
        self.birefnet.to(device)
        # Make inference mode explicit for the instance's own model.
        self.birefnet.eval()
        # Built once here instead of on every request.
        self.image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))

    def _load_image(self, image_src):
        """Return a PIL image for a PIL image, local path, URL or array."""
        if isinstance(image_src, Image.Image):
            return image_src
        if isinstance(image_src, str):
            # NOTE(security): `image_src` comes from the request, so this
            # reads arbitrary local files and fetches arbitrary URLs on the
            # caller's behalf. Restrict it if the endpoint is exposed to
            # untrusted clients.
            if os.path.isfile(image_src):
                return Image.open(image_src)
            response = requests.get(image_src, timeout=self.request_timeout)
            # Fail on HTTP errors instead of decoding an error page as an image.
            response.raise_for_status()
            return Image.open(BytesIO(response.content))
        return Image.fromarray(image_src)

    def __call__(self, data: Dict[str, Any]):
        """Run background removal on one request.

        Args:
            data: request payload; ``data["inputs"]`` is a local file path,
                an image URL, a PIL image, or an array accepted by
                ``Image.fromarray``.

        Returns:
            The refined foreground image with the predicted mask, resized to
            the input size, set as its alpha channel.

        Raises:
            requests.HTTPError: if a remote image cannot be downloaded.
        """
        print('data["inputs"] = ', data["inputs"])
        image = self._load_image(data["inputs"]).convert('RGB')

        # Preprocess the image and add the batch dimension.
        image_proc = self.image_preprocessor.proc(image).unsqueeze(0)

        # Prediction: use this instance's model (the original code called the
        # module-level `birefnet`, leaving `self.birefnet` unused).
        with torch.no_grad():
            preds = self.birefnet(image_proc.to(device))[-1].sigmoid().cpu()
        pred = preds[0].squeeze()

        # Compose the result: refined colours plus the mask as alpha.
        pred_pil = transforms.ToPILImage()(pred)
        image_masked = refine_foreground(image, pred_pil)
        image_masked.putalpha(pred_pil.resize(image.size))
        return image_masked
requirements.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ --extra-index-url https://download.pytorch.org/whl/cu118
2
+ torch==2.0.1
3
+ --extra-index-url https://download.pytorch.org/whl/cu118
4
+ torchvision==0.15.2
5
+ numpy<2
6
+ opencv-python
7
+ timm
8
+ scipy
9
+ scikit-image
10
+ kornia
11
+ einops
12
+
13
+ tqdm
14
+ prettytable
15
+
16
+ transformers
17
+ huggingface-hub>0.25
18
+ accelerate