Spaces:
Runtime error
Runtime error
import argparse | |
import cv2 | |
import os | |
import numpy as np | |
from skimage.metrics import mean_squared_error | |
from skimage.measure import compare_ssim | |
from skimage.metrics import structural_similarity | |
from skimage.metrics import peak_signal_noise_ratio | |
#import lpips | |
import torch | |
from tqdm import tqdm | |
#from niqe.niqe import compute_niqe | |
#criterion = lpips.LPIPS(net='vgg', lpips=True, pnet_rand=False, pretrained=True).cuda() | |
def rgb2ycbcr(im, only_y=True):
    """Convert an RGB image to YCbCr the same way MATLAB's ``rgb2ycbcr`` does.

    :param im: uint8 or float ndarray with the colour channels on the last
        axis. Any non-uint8 input is multiplied by 255 before conversion,
        i.e. it is treated as being scaled to [0, 1].
    :param only_y: when true return only the Y (luma) plane, otherwise
        return the three Y, Cb, Cr channels.
    :return: ndarray cast back to the dtype of ``im``. uint8 results are
        rounded; all other results are divided by 255 again.
    """
    src_dtype = im.dtype
    pixels = im.astype(np.float64)  # astype copies, the caller's array is untouched
    if src_dtype != np.uint8:
        pixels *= 255.

    if only_y:
        luma_weights = np.array([65.481, 128.553, 24.966]) / 255.0
        converted = np.dot(pixels, luma_weights) + 16.0
    else:
        # Columns produce Y, Cb and Cr respectively.
        transform = np.array([[65.481, -37.797, 112.0],
                              [128.553, -74.203, -93.786],
                              [24.966, 112.0, -18.214]]) / 255.0
        converted = np.matmul(pixels, transform) + [16, 128, 128]

    if src_dtype == np.uint8:
        converted = converted.round()
    else:
        converted /= 255.
    return converted.astype(src_dtype)
def rgb2ycbcrTorch(im, only_y=True):
    '''
    same as matlab rgb2ycbcr
    Input:
        im: float [0,1], N x 3 x H x W
        only_y: only return Y channel
    Output:
        float tensor clamped to [0,1]; N x 1 x H x W when only_y is true,
        otherwise N x 3 x H x W holding Y, Cb, Cr.
    '''
    # N x C x H x W --> N x H x W x C, rescaled to [0,255]
    im_temp = im.permute([0, 2, 3, 1]) * 255.0
    # convert
    if only_y:
        weights = torch.tensor([65.481, 128.553, 24.966],
                               device=im.device, dtype=im.dtype).view([3, 1])
        rlt = torch.matmul(im_temp, weights / 255.0) + 16.0
    else:
        weights = torch.tensor([[65.481, -37.797, 112.0],
                                [128.553, -74.203, -93.786],
                                [24.966, 112.0, -18.214]],
                               device=im.device, dtype=im.dtype)
        # Build the offset on the input's device/dtype: a default (CPU, int64)
        # tensor cannot be added to a CUDA tensor.
        offset = torch.tensor([16, 128, 128],
                              device=im.device, dtype=im.dtype).view([-1, 1, 1, 3])
        rlt = torch.matmul(im_temp, weights / 255.0) + offset
    rlt /= 255.0
    rlt.clamp_(0.0, 1.0)
    # N x H x W x C --> N x C x H x W
    return rlt.permute([0, 3, 1, 2])
def readim(file):
    """Read an image with OpenCV and return it as float32 divided by 255.

    :param file: path of the image file to load.
    :return: float32 ndarray holding the ``cv2.imread`` result scaled by
        1/255. OpenCV delivers colour images in BGR channel order.
    :raises OSError: if OpenCV cannot read or decode ``file``.
    """
    img = cv2.imread(file)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising,
        # which would otherwise surface as an opaque AttributeError below.
        raise OSError('cv2.imread could not read image: {!r}'.format(file))
    return img.astype(np.float32) / 255.
def loadfiles(folder):
    """Return the entry names of ``folder`` in natural sort order.

    Every entry reported by ``os.listdir`` is included; the names are
    relative to ``folder``, not full paths.
    """
    return natsorted(os.listdir(folder))
def resize(im, size, crop=True):
    """Bring ``im`` to ``size`` by cropping or by rescaling.

    :param im: image array.
    :param size: with ``crop`` the first two entries bound the first two
        axes of ``im``; without it the value is passed unchanged to
        ``cv2.resize``.
    :param crop: slice from the top-left corner when true, otherwise
        rescale with ``cv2.resize``.
    :return: the cropped view or the rescaled image.
    """
    if not crop:
        return cv2.resize(im, size)
    row_limit = size[0]
    col_limit = size[1]
    return im[:row_limit, :col_limit]
from natsort import natsorted | |
def np2torch(img, device=None):
    """Convert a channels-last image array into a 1 x C x H x W float tensor.

    :param img: H x W x C ndarray; its values are divided by 255
        (presumably a [0, 255] image -- confirm against callers).
    :param device: torch device for the result. When omitted, CUDA is used
        if it is available and the CPU otherwise, so the helper no longer
        crashes on machines without a GPU.
    :return: float32 tensor of shape 1 x C x H x W on ``device``.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    im = img.astype(np.float32) / 255
    # H x W x C --> C x H x W, then add the batch axis
    im = torch.tensor(im).permute((2, 0, 1)).unsqueeze(0)
    return im.to(device)
def compute_metrics(path1, path2, ycbcr=True, crop=False):
    """Compute mean MSE, PSNR and SSIM between the images of two folders.

    Files are listed with ``loadfiles`` and paired by position in the
    naturally sorted listings, not by matching names.

    :param path1: folder of images to evaluate.
    :param path2: folder of reference images.
    :param ycbcr: when true, PSNR and SSIM are computed on the Y channel
        only; MSE is always computed on the full colour images.
    :param crop: how an image from ``path1`` is adapted when its shape
        differs from its partner: cropped from the top-left corner when
        true, rescaled with ``cv2.resize`` when false (the default).
    :return: tuple ``(mean_mse, mean_psnr, mean_ssim)``; all NaN when no
        image pair was found.
    """
    print(path1)
    files1 = loadfiles(path1)
    files2 = loadfiles(path2)
    print(len(files1), len(files2))
    n_pairs = min(len(files1), len(files2))
    if len(files1) != len(files2):
        # zip() pairs by position and silently drops the surplus entries.
        print('Warning: folders contain different numbers of files; '
              'only the first {} pairs are compared.'.format(n_pairs))

    psnr = []
    ssim = []
    mse = []
    for file1, file2 in tqdm(zip(files1, files2), total=n_pairs):
        img1 = readim(os.path.join(path1, file1))
        img2 = readim(os.path.join(path2, file2))
        if img1.shape != img2.shape:
            if crop:
                img1 = resize(img1, img2.shape, True)
            else:
                # cv2.resize expects (width, height), hence the reversal
                img1 = resize(img1, img2.shape[:2][::-1], False)
        mse.append(mean_squared_error(img1, img2))
        if ycbcr:
            # readim() returns OpenCV's BGR channel order while rgb2ycbcr
            # weights the channels as R, G, B -- flip the last axis first.
            img1 = rgb2ycbcr(img1[..., ::-1], True)
            img2 = rgb2ycbcr(img2[..., ::-1], True)
        psnr.append(peak_signal_noise_ratio(img1, img2, data_range=1))
        ssim.append(structural_similarity(img1, img2, win_size=11,
                                          multichannel=not ycbcr,
                                          data_range=1))

    if not mse:
        # Avoid np.mean([]) and its RuntimeWarning; the result is undefined.
        print('Warning: no image pairs found, metrics are undefined.')
        mean_mse = mean_psnr = mean_ssim = np.nan
    else:
        mean_mse, mean_psnr, mean_ssim = np.mean(mse), np.mean(psnr), np.mean(ssim)
    print(mean_mse, mean_psnr, mean_ssim)
    return mean_mse, mean_psnr, mean_ssim
if __name__ == "__main__":
    # Command-line entry point: compare the images found in two folders.
    cli = argparse.ArgumentParser()
    # path setting
    cli.add_argument('--path1', type=str, default="")
    cli.add_argument('--path2', type=str, default="")
    options = cli.parse_args()

    # In-file fallbacks, used whenever the matching option is absent or empty.
    path1 = ''
    path2 = ''
    if options.path1:
        path1 = options.path1
    if options.path2:
        path2 = options.path2

    compute_metrics(path1, path2, True)