Imadsarvm committed on
Commit
8590454
·
1 Parent(s): bf22f55

Upload utils_util_calculate_psnr_ssim.py

Browse files
utils/utils_util_calculate_psnr_ssim.py ADDED
@@ -0,0 +1,346 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import torch
4
+
5
+
6
def calculate_psnr(img1, img2, crop_border, input_order='HWC', test_y_channel=False):
    """Calculate PSNR (Peak Signal-to-Noise Ratio).

    Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio

    Args:
        img1 (ndarray): Images with range [0, 255].
        img2 (ndarray): Images with range [0, 255].
        crop_border (int): Cropped pixels in each edge of an image. These
            pixels are not involved in the PSNR calculation.
        input_order (str): Whether the input order is 'HWC' or 'CHW'.
            Default: 'HWC'.
        test_y_channel (bool): Test on Y channel of YCbCr. Default: False.

    Returns:
        float: psnr result. ``inf`` is returned when the mean squared error
            between the two images is exactly zero.

    Raises:
        AssertionError: If the two images do not have the same shape.
        ValueError: If ``input_order`` is neither 'HWC' nor 'CHW'.
    """

    assert img1.shape == img2.shape, (f'Image shapes are different: {img1.shape}, {img2.shape}.')
    if input_order not in ['HWC', 'CHW']:
        raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"')

    # Bring both images to (h, w, c) and to float64 so the squared
    # difference below is not computed in an integer dtype.
    img1 = reorder_image(img1, input_order=input_order).astype(np.float64)
    img2 = reorder_image(img2, input_order=input_order).astype(np.float64)

    if crop_border != 0:
        # Only the two spatial axes are cropped; channels are kept.
        img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
        img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]

    if test_y_channel:
        img1 = to_y_channel(img1)
        img2 = to_y_channel(img2)

    mse = np.mean((img1 - img2) ** 2)
    if mse == 0:
        # Identical images: avoid dividing by zero in the log term.
        return float('inf')
    # The peak value is 255 because inputs are in the [0, 255] range.
    return 20. * np.log10(255. / np.sqrt(mse))
44
+
45
+
46
def _ssim(img1, img2):
    """Compute the mean SSIM (structural similarity) of two single-channel images.

    Helper for :func:`calculate_ssim`, which calls it once per channel.

    Args:
        img1 (ndarray): Single-channel image with range [0, 255].
        img2 (ndarray): Single-channel image with range [0, 255].

    Returns:
        float: Mean of the SSIM map.
    """
    # Stabilising constants for a dynamic range of 255.
    c1 = (0.01 * 255) ** 2
    c2 = (0.03 * 255) ** 2

    x = img1.astype(np.float64)
    y = img2.astype(np.float64)

    # 2-D Gaussian window built as the outer product of an 11-tap,
    # sigma=1.5 1-D kernel with itself.
    gauss_1d = cv2.getGaussianKernel(11, 1.5)
    gauss_2d = np.outer(gauss_1d, gauss_1d.transpose())

    def _smooth(arr):
        # Filter, then drop a 5-pixel border on every side so only positions
        # where the 11x11 window lies fully inside the image are kept.
        return cv2.filter2D(arr, -1, gauss_2d)[5:-5, 5:-5]

    mean_x = _smooth(x)
    mean_y = _smooth(y)
    mean_x_sq = mean_x ** 2
    mean_y_sq = mean_y ** 2
    mean_xy = mean_x * mean_y

    # Local variances / covariance via E[a*b] - E[a]E[b].
    var_x = _smooth(x ** 2) - mean_x_sq
    var_y = _smooth(y ** 2) - mean_y_sq
    cov_xy = _smooth(x * y) - mean_xy

    numerator = (2 * mean_xy + c1) * (2 * cov_xy + c2)
    denominator = (mean_x_sq + mean_y_sq + c1) * (var_x + var_y + c2)
    return (numerator / denominator).mean()
78
+
79
+
80
def calculate_ssim(img1, img2, crop_border, input_order='HWC', test_y_channel=False):
    """Calculate SSIM (structural similarity).

    Ref:
    Image quality assessment: From error visibility to structural similarity

    The results are the same as that of the official released MATLAB code in
    https://ece.uwaterloo.ca/~z70wang/research/ssim/.

    For three-channel images, SSIM is calculated for each channel and then
    averaged.

    Args:
        img1 (ndarray): Images with range [0, 255].
        img2 (ndarray): Images with range [0, 255].
        crop_border (int): Cropped pixels in each edge of an image. These
            pixels are not involved in the SSIM calculation.
        input_order (str): Whether the input order is 'HWC' or 'CHW'.
            Default: 'HWC'.
        test_y_channel (bool): Test on Y channel of YCbCr. Default: False.

    Returns:
        float: ssim result (mean over the channels).

    Raises:
        AssertionError: If the two images do not have the same shape.
        ValueError: If ``input_order`` is neither 'HWC' nor 'CHW'.
    """

    assert img1.shape == img2.shape, (f'Image shapes are different: {img1.shape}, {img2.shape}.')
    if input_order not in ['HWC', 'CHW']:
        raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"')

    # Bring both images to (h, w, c) float64 before any arithmetic.
    img1 = reorder_image(img1, input_order=input_order).astype(np.float64)
    img2 = reorder_image(img2, input_order=input_order).astype(np.float64)

    if crop_border != 0:
        # Only the two spatial axes are cropped; channels are kept.
        img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
        img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]

    if test_y_channel:
        img1 = to_y_channel(img1)
        img2 = to_y_channel(img2)

    # One SSIM score per channel along the last axis, then their mean.
    channel_scores = [_ssim(img1[..., idx], img2[..., idx]) for idx in range(img1.shape[2])]
    return np.array(channel_scores).mean()
125
+
126
+
127
def _blocking_effect_factor(im):
    """Compute the blocking effect factor (BEF) used by PSNR-B.

    Squared differences between neighbouring pixels are averaged separately
    for pairs that straddle an 8-pixel block boundary and for all other
    pairs. The BEF is the scaled excess of the boundary average over the
    non-boundary average, and zero when there is no excess.

    Args:
        im (Tensor): 4-D tensor whose axes 2 and 3 are scanned for block
            boundaries; axes 1-3 are summed, so one value is produced per
            entry of axis 0.

    Returns:
        Tensor: 1-D tensor with one BEF value per entry of axis 0. All zeros
            when the input has no complete block boundary (both scanned axes
            shorter than two blocks) or when a scanned axis has length < 2,
            since the measure is undefined there.
    """
    block_size = 8
    height, width = im.shape[2], im.shape[3]

    # Number of neighbouring-pixel pairs counted as block-boundary pairs.
    # Clamped at zero so an axis shorter than one block cannot yield a
    # negative count.
    n_boundary_horiz = height * max(width // block_size - 1, 0)
    n_boundary_vert = width * max(height // block_size - 1, 0)
    n_boundary = n_boundary_horiz + n_boundary_vert

    if n_boundary == 0 or min(height, width) < 2:
        # Without this guard the averages below divide by zero (nan/inf) and
        # the scaler divides by log2(1) == 0.
        out_dtype = im.dtype if im.is_floating_point() else torch.get_default_dtype()
        return torch.zeros(im.shape[0], dtype=out_dtype, device=im.device)

    def _split_positions(length):
        # Indices p such that (p, p + 1) straddles a block boundary, and all
        # remaining valid p. range() is used because torch.arange raises
        # when length - 1 is smaller than the start value.
        block = torch.tensor(
            list(range(block_size - 1, length - 1, block_size)), dtype=torch.long, device=im.device)
        is_block = torch.zeros(length - 1, dtype=torch.bool, device=im.device)
        is_block[block] = True
        nonblock = torch.arange(length - 1, device=im.device)[~is_block]
        return block, nonblock

    block_cols, nonblock_cols = _split_positions(width)
    block_rows, nonblock_rows = _split_positions(height)

    horizontal_block_difference = (
        (im[:, :, :, block_cols] - im[:, :, :, block_cols + 1]) ** 2).sum(3).sum(2).sum(1)
    vertical_block_difference = (
        (im[:, :, block_rows, :] - im[:, :, block_rows + 1, :]) ** 2).sum(3).sum(2).sum(1)
    horizontal_nonblock_difference = (
        (im[:, :, :, nonblock_cols] - im[:, :, :, nonblock_cols + 1]) ** 2).sum(3).sum(2).sum(1)
    vertical_nonblock_difference = (
        (im[:, :, nonblock_rows, :] - im[:, :, nonblock_rows + 1, :]) ** 2).sum(3).sum(2).sum(1)

    boundary_difference = (horizontal_block_difference + vertical_block_difference) / n_boundary

    # Every neighbouring pair that is not a boundary pair.
    n_nonboundary_horiz = height * (width - 1) - n_boundary_horiz
    n_nonboundary_vert = width * (height - 1) - n_boundary_vert
    nonboundary_difference = (horizontal_nonblock_difference + vertical_nonblock_difference) / (
        n_nonboundary_horiz + n_nonboundary_vert)

    # Plain Python float so the product below stays a torch tensor.
    scaler = float(np.log2(block_size) / np.log2(min(height, width)))
    bef = scaler * (boundary_difference - nonboundary_difference)

    # No penalty when boundaries are not rougher than the image interior.
    bef[boundary_difference <= nonboundary_difference] = 0
    return bef
165
+
166
+
167
def calculate_psnrb(img1, img2, crop_border, input_order='HWC', test_y_channel=False):
    """Calculate PSNR-B (Peak Signal-to-Noise Ratio including Blocking effects).

    Ref: Quality assessment of deblocked images, for JPEG image deblocking evaluation
    # https://gitlab.com/Queuecumber/quantization-guided-ac/-/blob/master/metrics/psnrb.py

    The blocking effect factor is computed from ``img1`` only, so the two
    arguments are not interchangeable.

    Args:
        img1 (ndarray): Images with range [0, 255].
        img2 (ndarray): Images with range [0, 255].
        crop_border (int): Cropped pixels in each edge of an image. These
            pixels are not involved in the PSNR calculation.
        input_order (str): Whether the input order is 'HWC' or 'CHW'.
            Default: 'HWC'.
        test_y_channel (bool): Test on Y channel of YCbCr. Default: False.

    Returns:
        float: psnr-b result, averaged over the channels.

    Raises:
        AssertionError: If the two images do not have the same shape.
        ValueError: If ``input_order`` is neither 'HWC' nor 'CHW'.
    """

    assert img1.shape == img2.shape, (f'Image shapes are different: {img1.shape}, {img2.shape}.')
    if input_order not in ['HWC', 'CHW']:
        raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"')

    # Bring both images to (h, w, c) float64 before any arithmetic.
    img1 = reorder_image(img1, input_order=input_order).astype(np.float64)
    img2 = reorder_image(img2, input_order=input_order).astype(np.float64)

    if crop_border != 0:
        # Only the two spatial axes are cropped; channels are kept.
        img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
        img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]

    if test_y_channel:
        img1 = to_y_channel(img1)
        img2 = to_y_channel(img2)

    # follow https://gitlab.com/Queuecumber/quantization-guided-ac/-/blob/master/metrics/psnrb.py
    # (h, w, c) -> (1, c, h, w), rescaled to [0, 1] so the peak value is 1.
    img1 = torch.from_numpy(img1).permute(2, 0, 1).unsqueeze(0) / 255.
    img2 = torch.from_numpy(img2).permute(2, 0, 1).unsqueeze(0) / 255.

    n_channels = img1.shape[1]
    total = 0
    for c in range(n_channels):
        chan1 = img1[:, c:c + 1, :, :]
        chan2 = img2[:, c:c + 1, :, :]

        mse = torch.nn.functional.mse_loss(chan1, chan2, reduction='none')
        bef = _blocking_effect_factor(chan1)

        # Per-sample mean squared error; the BEF is added as a penalty term.
        mse = mse.view(mse.shape[0], -1).mean(1)
        total += 10 * torch.log10(1 / (mse + bef))

    return float(total) / n_channels
215
+
216
+
217
def reorder_image(img, input_order='HWC'):
    """Reorder images to 'HWC' order.

    If the input_order is (h, w), return (h, w, 1);
    If the input_order is (c, h, w), return (h, w, c);
    If the input_order is (h, w, c), return as it is.

    Args:
        img (ndarray): Input image.
        input_order (str): Whether the input order is 'HWC' or 'CHW'.
            If the input image shape is (h, w), input_order will not have
            effects. Default: 'HWC'.

    Returns:
        ndarray: reordered image.

    Raises:
        ValueError: If ``input_order`` is neither 'HWC' nor 'CHW'.
    """

    if input_order not in ['HWC', 'CHW']:
        raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' "'HWC' and 'CHW'")
    if len(img.shape) == 2:
        # A 2-D image has no channel axis to move. Return right away so that
        # input_order='CHW' does not transpose the freshly added axis and
        # turn (h, w) into (w, 1, h).
        return img[..., None]
    if input_order == 'CHW':
        img = img.transpose(1, 2, 0)
    return img
241
+
242
+
243
def to_y_channel(img):
    """Extract the Y (luma) channel of YCbCr from an image.

    Three-channel inputs are converted with :func:`bgr2ycbcr`, so they are
    interpreted in BGR order. Any other input is passed through after the
    float32 round trip described below.

    Args:
        img (ndarray): Images with range [0, 255].

    Returns:
        (ndarray): Images with range [0, 255] (float type) without round.
    """
    # bgr2ycbcr expects float32 data in [0, 1], so rescale first.
    scaled = img.astype(np.float32) / 255.
    has_three_channels = scaled.ndim == 3 and scaled.shape[2] == 3
    if has_three_channels:
        luma = bgr2ycbcr(scaled, y_only=True)
        # Restore a trailing channel axis: (h, w) -> (h, w, 1).
        scaled = luma[..., None]
    # Back to the [0, 255] range.
    return scaled * 255.
257
+
258
+
259
def _convert_input_type_range(img):
    """Normalise an input image to np.float32 in the [0, 1] range.

    Used as the pre-processing step of the colorspace conversion functions
    (e.g. bgr2ycbcr).

    Args:
        img (ndarray): The input image. It accepts:
            1. np.uint8 type with range [0, 255];
            2. np.float32 type with range [0, 1].

    Returns:
        (ndarray): The converted image with type of np.float32 and range of
            [0, 1]. This is always a new array, even for float32 input.

    Raises:
        TypeError: If the input dtype is neither np.float32 nor np.uint8.
    """
    source_dtype = img.dtype
    # astype copies, so the in-place division below never touches the
    # caller's array.
    converted = img.astype(np.float32)
    if source_dtype == np.uint8:
        converted /= 255.
    elif source_dtype != np.float32:
        raise TypeError('The img type should be np.float32 or np.uint8, ' f'but got {source_dtype}')
    return converted
284
+
285
+
286
def _convert_output_type_range(img, dst_type):
    """Convert an image in the [0, 255] range to the requested type and range.

    Used as the post-processing step of the colorspace conversion functions
    (e.g. bgr2ycbcr).

    Args:
        img (ndarray): The image to be converted with np.float32 type and
            range [0, 255]. For a np.float32 destination the array is
            divided in place before the final cast.
        dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
            converts the image to np.uint8 type with range [0, 255]. If
            dst_type is np.float32, it converts the image to np.float32 type
            with range [0, 1].

    Returns:
        (ndarray): The converted image with desired type and range.

    Raises:
        TypeError: If ``dst_type`` is neither np.uint8 nor np.float32.
    """
    if dst_type not in (np.uint8, np.float32):
        raise TypeError('The dst_type should be np.float32 or np.uint8, ' f'but got {dst_type}')
    if dst_type != np.uint8:
        # float32 target: rescale [0, 255] -> [0, 1].
        img /= 255.
        return img.astype(dst_type)
    # uint8 target: round to the nearest integer level before casting.
    return img.round().astype(dst_type)
314
+
315
+
316
def bgr2ycbcr(img, y_only=False):
    """Convert a BGR image to YCbCr.

    This is the BGR counterpart of rgb2ycbcr and implements the ITU-R BT.601
    conversion for standard-definition television. See more details in
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.

    It is not the same as cv2.cvtColor's `BGR <-> YCrCb`, which implements
    the JPEG conversion. See more details in
    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.

    Args:
        img (ndarray): The input image. It accepts:
            1. np.uint8 type with range [0, 255];
            2. np.float32 type with range [0, 1].
        y_only (bool): Whether to only return Y channel. Default: False.

    Returns:
        ndarray: The converted YCbCr image. The output image has the same type
            and range as input image.
    """
    source_dtype = img.dtype
    unit_img = _convert_input_type_range(img)
    if not y_only:
        # Rows follow the input channel order (B, G, R); columns are the
        # Y, Cb and Cr outputs.
        weights = [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]
        offsets = [16, 128, 128]
        converted = np.matmul(unit_img, weights) + offsets
    else:
        # Luma only: weighted sum over the last axis plus the Y offset.
        converted = np.dot(unit_img, [24.966, 128.553, 65.481]) + 16.0
    # Restore the dtype and range the caller passed in.
    return _convert_output_type_range(converted, source_dtype)