shunk031 commited on
Commit
70c20dc
·
verified ·
1 Parent(s): 695da21

Upload processor

Browse files
image_processing_basnet.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Tuple, Union
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import torch
6
+ from PIL import Image
7
+ from PIL.Image import Image as PilImage
8
+ from torchvision import transforms
9
+ from transformers.image_processing_base import BatchFeature
10
+ from transformers.image_processing_utils import BaseImageProcessor
11
+ from transformers.image_utils import ImageInput
12
+
13
+
14
+ class RescaleT(object):
15
+ def __init__(self, output_size: Union[int, Tuple[int, int]]) -> None:
16
+ super().__init__()
17
+ assert isinstance(output_size, (int, tuple))
18
+ self.output_size = output_size
19
+
20
+ def __call__(self, sample) -> Dict[str, np.ndarray]:
21
+ image, label = sample["image"], sample["label"]
22
+
23
+ h, w = image.shape[:2]
24
+
25
+ if isinstance(self.output_size, int):
26
+ if h > w:
27
+ new_h, new_w = self.output_size * h / w, self.output_size
28
+ else:
29
+ new_h, new_w = self.output_size, self.output_size * w / h
30
+ else:
31
+ new_h, new_w = self.output_size
32
+
33
+ new_h, new_w = int(new_h), int(new_w)
34
+
35
+ # resize the image to new_h x new_w and convert image from range [0,255] to [0,1]
36
+ # img = transform.resize(image,(new_h,new_w),mode='constant')
37
+ # lbl = transform.resize(label,(new_h,new_w),mode='constant', order=0, preserve_range=True)
38
+
39
+ # img = transform.resize(image, (self.output_size, self.output_size), mode='constant')
40
+ img = (
41
+ cv2.resize(
42
+ image,
43
+ (self.output_size, self.output_size),
44
+ interpolation=cv2.INTER_AREA,
45
+ )
46
+ / 255.0
47
+ )
48
+ # lbl = transform.resize(label, (self.output_size, self.output_size),
49
+ # mode='constant',
50
+ # order=0,
51
+ # preserve_range=True)
52
+ lbl = cv2.resize(
53
+ label, (self.output_size, self.output_size), interpolation=cv2.INTER_NEAREST
54
+ )
55
+ lbl = np.expand_dims(lbl, axis=-1)
56
+ lbl = np.clip(lbl, np.min(label), np.max(label))
57
+
58
+ return {"image": img, "label": lbl}
59
+
60
+
61
+ class ToTensorLab(object):
62
+ """Convert ndarrays in sample to Tensors."""
63
+
64
+ def __init__(self, flag: int = 0) -> None:
65
+ self.flag = flag
66
+
67
+ def __call__(self, sample):
68
+ image, label = sample["image"], sample["label"]
69
+
70
+ tmpLbl = np.zeros(label.shape)
71
+
72
+ if np.max(label) < 1e-6:
73
+ label = label
74
+ else:
75
+ label = label / np.max(label)
76
+
77
+ # change the color space
78
+ if self.flag == 2: # with rgb and Lab colors
79
+ tmpImg = np.zeros((image.shape[0], image.shape[1], 6))
80
+ tmpImgt = np.zeros((image.shape[0], image.shape[1], 3))
81
+ if image.shape[2] == 1:
82
+ tmpImgt[:, :, 0] = image[:, :, 0]
83
+ tmpImgt[:, :, 1] = image[:, :, 0]
84
+ tmpImgt[:, :, 2] = image[:, :, 0]
85
+ else:
86
+ tmpImgt = image
87
+ # tmpImgtl = color.rgb2lab(tmpImgt)
88
+ tmpImgtl = cv2.cvtColor(tmpImgt, cv2.COLOR_RGB2LAB)
89
+
90
+ # nomalize image to range [0,1]
91
+ tmpImg[:, :, 0] = (tmpImgt[:, :, 0] - np.min(tmpImgt[:, :, 0])) / (
92
+ np.max(tmpImgt[:, :, 0]) - np.min(tmpImgt[:, :, 0])
93
+ )
94
+ tmpImg[:, :, 1] = (tmpImgt[:, :, 1] - np.min(tmpImgt[:, :, 1])) / (
95
+ np.max(tmpImgt[:, :, 1]) - np.min(tmpImgt[:, :, 1])
96
+ )
97
+ tmpImg[:, :, 2] = (tmpImgt[:, :, 2] - np.min(tmpImgt[:, :, 2])) / (
98
+ np.max(tmpImgt[:, :, 2]) - np.min(tmpImgt[:, :, 2])
99
+ )
100
+ tmpImg[:, :, 3] = (tmpImgtl[:, :, 0] - np.min(tmpImgtl[:, :, 0])) / (
101
+ np.max(tmpImgtl[:, :, 0]) - np.min(tmpImgtl[:, :, 0])
102
+ )
103
+ tmpImg[:, :, 4] = (tmpImgtl[:, :, 1] - np.min(tmpImgtl[:, :, 1])) / (
104
+ np.max(tmpImgtl[:, :, 1]) - np.min(tmpImgtl[:, :, 1])
105
+ )
106
+ tmpImg[:, :, 5] = (tmpImgtl[:, :, 2] - np.min(tmpImgtl[:, :, 2])) / (
107
+ np.max(tmpImgtl[:, :, 2]) - np.min(tmpImgtl[:, :, 2])
108
+ )
109
+
110
+ # tmpImg = tmpImg/(np.max(tmpImg)-np.min(tmpImg))
111
+
112
+ tmpImg[:, :, 0] = (tmpImg[:, :, 0] - np.mean(tmpImg[:, :, 0])) / np.std(
113
+ tmpImg[:, :, 0]
114
+ )
115
+ tmpImg[:, :, 1] = (tmpImg[:, :, 1] - np.mean(tmpImg[:, :, 1])) / np.std(
116
+ tmpImg[:, :, 1]
117
+ )
118
+ tmpImg[:, :, 2] = (tmpImg[:, :, 2] - np.mean(tmpImg[:, :, 2])) / np.std(
119
+ tmpImg[:, :, 2]
120
+ )
121
+ tmpImg[:, :, 3] = (tmpImg[:, :, 3] - np.mean(tmpImg[:, :, 3])) / np.std(
122
+ tmpImg[:, :, 3]
123
+ )
124
+ tmpImg[:, :, 4] = (tmpImg[:, :, 4] - np.mean(tmpImg[:, :, 4])) / np.std(
125
+ tmpImg[:, :, 4]
126
+ )
127
+ tmpImg[:, :, 5] = (tmpImg[:, :, 5] - np.mean(tmpImg[:, :, 5])) / np.std(
128
+ tmpImg[:, :, 5]
129
+ )
130
+
131
+ elif self.flag == 1: # with Lab color
132
+ tmpImg = np.zeros((image.shape[0], image.shape[1], 3))
133
+
134
+ if image.shape[2] == 1:
135
+ tmpImg[:, :, 0] = image[:, :, 0]
136
+ tmpImg[:, :, 1] = image[:, :, 0]
137
+ tmpImg[:, :, 2] = image[:, :, 0]
138
+ else:
139
+ tmpImg = image
140
+
141
+ # tmpImg = color.rgb2lab(tmpImg)
142
+ print("tmpImg:", tmpImg.min(), tmpImg.max())
143
+ exit()
144
+ tmpImg = cv2.cvtColor(tmpImg, cv2.COLOR_RGB2LAB)
145
+
146
+ # tmpImg = tmpImg/(np.max(tmpImg)-np.min(tmpImg))
147
+
148
+ tmpImg[:, :, 0] = (tmpImg[:, :, 0] - np.min(tmpImg[:, :, 0])) / (
149
+ np.max(tmpImg[:, :, 0]) - np.min(tmpImg[:, :, 0])
150
+ )
151
+ tmpImg[:, :, 1] = (tmpImg[:, :, 1] - np.min(tmpImg[:, :, 1])) / (
152
+ np.max(tmpImg[:, :, 1]) - np.min(tmpImg[:, :, 1])
153
+ )
154
+ tmpImg[:, :, 2] = (tmpImg[:, :, 2] - np.min(tmpImg[:, :, 2])) / (
155
+ np.max(tmpImg[:, :, 2]) - np.min(tmpImg[:, :, 2])
156
+ )
157
+
158
+ tmpImg[:, :, 0] = (tmpImg[:, :, 0] - np.mean(tmpImg[:, :, 0])) / np.std(
159
+ tmpImg[:, :, 0]
160
+ )
161
+ tmpImg[:, :, 1] = (tmpImg[:, :, 1] - np.mean(tmpImg[:, :, 1])) / np.std(
162
+ tmpImg[:, :, 1]
163
+ )
164
+ tmpImg[:, :, 2] = (tmpImg[:, :, 2] - np.mean(tmpImg[:, :, 2])) / np.std(
165
+ tmpImg[:, :, 2]
166
+ )
167
+
168
+ else: # with rgb color
169
+ tmpImg = np.zeros((image.shape[0], image.shape[1], 3))
170
+ image = image / np.max(image)
171
+ if image.shape[2] == 1:
172
+ tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
173
+ tmpImg[:, :, 1] = (image[:, :, 0] - 0.485) / 0.229
174
+ tmpImg[:, :, 2] = (image[:, :, 0] - 0.485) / 0.229
175
+ else:
176
+ tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
177
+ tmpImg[:, :, 1] = (image[:, :, 1] - 0.456) / 0.224
178
+ tmpImg[:, :, 2] = (image[:, :, 2] - 0.406) / 0.225
179
+
180
+ tmpLbl[:, :, 0] = label[:, :, 0]
181
+
182
+ # change the r,g,b to b,r,g from [0,255] to [0,1]
183
+ # transforms.Normalize(mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225))
184
+ tmpImg = tmpImg.transpose((2, 0, 1))
185
+ tmpLbl = label.transpose((2, 0, 1))
186
+
187
+ return {"image": torch.from_numpy(tmpImg), "label": torch.from_numpy(tmpLbl)}
188
+
189
+
190
+ def apply_transform(
191
+ data: Dict[str, np.ndarray], rescale_size: int, to_tensor_lab_flag: int
192
+ ) -> Dict[str, torch.Tensor]:
193
+ transform = transforms.Compose(
194
+ [RescaleT(output_size=rescale_size), ToTensorLab(flag=to_tensor_lab_flag)]
195
+ )
196
+ return transform(data) # type: ignore
197
+
198
+
199
+ class BASNetImageProcessor(BaseImageProcessor):
200
+ model_input_names = ["pixel_values"]
201
+
202
+ def __init__(
203
+ self, rescale_size: int = 256, to_tensor_lab_flag: int = 0, **kwargs
204
+ ) -> None:
205
+ super().__init__(**kwargs)
206
+ self.rescale_size = rescale_size
207
+ self.to_tensor_lab_flag = to_tensor_lab_flag
208
+
209
+ def preprocess(self, images: ImageInput, **kwargs) -> BatchFeature:
210
+ if not isinstance(images, PilImage):
211
+ raise ValueError(f"Expected PIL.Image, got {type(images)}")
212
+
213
+ image_pil = images
214
+ image_npy = np.array(image_pil, dtype=np.uint8)
215
+ width, height = image_pil.size
216
+ label_npy = np.zeros((height, width), dtype=np.uint8)
217
+
218
+ assert image_npy.shape[-1] == 3
219
+ output = apply_transform(
220
+ {"image": image_npy, "label": label_npy},
221
+ rescale_size=self.rescale_size,
222
+ to_tensor_lab_flag=self.to_tensor_lab_flag,
223
+ )
224
+ image = output["image"]
225
+
226
+ assert isinstance(image, torch.Tensor)
227
+
228
+ return BatchFeature(
229
+ data={"pixel_values": image.float().unsqueeze(dim=0)}, tensor_type="pt"
230
+ )
231
+
232
+ def postprocess(
233
+ self, prediction: torch.Tensor, width: int, height: int
234
+ ) -> PilImage:
235
+ def _norm_prediction(d: torch.Tensor) -> torch.Tensor:
236
+ ma, mi = torch.max(d), torch.min(d)
237
+
238
+ # division while avoiding zero division
239
+ dn = (d - mi) / ((ma - mi) + torch.finfo(torch.float32).eps)
240
+ return dn
241
+
242
+ # prediction = _norm_output(prediction)
243
+ # prediction = prediction.squeeze()
244
+ # prediction_np = prediction.cpu().numpy()
245
+
246
+ # image = Image.fromarray(prediction_np * 255).convert("RGB")
247
+ # image = image.resize((width, height), resample=Image.Resampling.BILINEAR)
248
+
249
+ # return image
250
+
251
+ # breakpoint()
252
+
253
+ # output = F.interpolate(output, (height, width), mode="bilinear")
254
+ # output = output.squeeze(dim=0)
255
+
256
+ # output = _norm_output(output)
257
+
258
+ # # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
259
+ # output = output * 255 + 0.5
260
+ # output = output.clamp(0, 255)
261
+
262
+ # # shape: (C=1, W, H) -> (W, H, C=1)
263
+ # output = output.permute(1, 2, 0)
264
+ # # shape: (W, H, C=3)
265
+ # output = output.repeat(1, 1, 3)
266
+
267
+ # output_np = output.cpu().numpy().astype(np.uint8)
268
+ # return Image.fromarray(output_np)
269
+
270
+ prediction = _norm_prediction(prediction)
271
+ prediction = prediction.squeeze()
272
+ prediction = prediction * 255 + 0.5
273
+ prediction = prediction.clamp(0, 255)
274
+
275
+ prediction_np = prediction.cpu().numpy()
276
+ image = Image.fromarray(prediction_np).convert("RGB")
277
+ image = image.resize((width, height), resample=Image.Resampling.BILINEAR)
278
+ return image
preprocessor_config.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "auto_map": {
3
+ "AutoImageProcessor": "image_processing_basnet.BASNetImageProcessor"
4
+ },
5
+ "image_processor_type": "BASNetImageProcessor",
6
+ "rescale_size": 256,
7
+ "to_tensor_lab_flag": 0
8
+ }