File size: 3,952 Bytes
ec0fdfd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import os
import glob
import shutil
import lpips
import numpy as np
import argparse
from PIL import Image
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
from dataloader.image_folder import make_dataset
from util import util
import torch

# Command-line options: where the ground-truth and generated images live,
# where to copy the best generated sample, and how much of each image to score.
parser = argparse.ArgumentParser(
    description='Image quality evaluations on the dataset',
)
parser.add_argument(
    '--gt_path',
    type=str,
    default='../results/',
    help='path to original gt data',
)
parser.add_argument(
    '--g_path',
    type=str,
    default='../results.',
    help='path to the generated data',
)
parser.add_argument(
    '--save_path',
    type=str,
    default=None,
    help='path to save the best results',
)
parser.add_argument(
    '--center',
    action='store_true',
    help='only calculate the center masked regions for the image quality',
)
parser.add_argument(
    '--num_test',
    type=int,
    default=0,
    help='how many examples to load for testing',
)

# NOTE: arguments are parsed at import time, so importing this module from
# another program consumes that program's sys.argv.
args = parser.parse_args()

# LPIPS model with the 'alex' backbone, built once and reused for every
# image pair scored by calculate_score().
lpips_alex = lpips.LPIPS(net='alex')


def calculate_score(img_gt, img_test):
    """
    function to calculate the image quality scores between two images
    :param img_gt: original image; float array laid out as H x W x C (it is
        permuted to C x H x W for LPIPS) with values scaled to [0, 1], which
        is what data_range=1 and normalize=True below assume
    :param img_test: generated image, same layout and shape as img_gt
    :return: l1loss, ssim, psnr, lpips (in this order); lower is better for
        l1loss and lpips, higher is better for ssim and psnr
    """
    import inspect

    # mean absolute error over every pixel and channel
    l1loss = np.mean(np.abs(img_gt - img_test))

    psnr_score = psnr(img_gt, img_test, data_range=1)

    # scikit-image 0.19 replaced `multichannel=True` with `channel_axis`.
    # Later releases no longer honour the old keyword, so the colour axis
    # would be treated as spatial and win_size=11 would exceed it. Pick the
    # keyword that the installed version actually declares.
    ssim_kwargs = {'data_range': 1, 'win_size': 11}
    if 'channel_axis' in inspect.signature(ssim).parameters:
        ssim_kwargs['channel_axis'] = -1
    else:
        ssim_kwargs['multichannel'] = True
    ssim_score = ssim(img_gt, img_test, **ssim_kwargs)

    # Only the scalar distance is needed, so skip autograd bookkeeping.
    # normalize=True tells LPIPS the inputs are in [0, 1] rather than [-1, 1].
    with torch.no_grad():
        lpips_dis = lpips_alex(
            torch.from_numpy(img_gt).permute(2, 0, 1),
            torch.from_numpy(img_test).permute(2, 0, 1),
            normalize=True,
        )

    return l1loss, ssim_score, psnr_score, lpips_dis.item()


def _load_rgb_array(path, center):
    """
    load an image as a float32 RGB array scaled to [0, 1]
    :param path: image file to read
    :param center: if True, keep only rows/columns 64:192 of the resized
        256x256 image (the central 128x128 region)
    :return: array of shape (256, 256, 3), or (128, 128, 3) when center is True
    """
    # the context manager releases the file handle as soon as the resized
    # copy exists, instead of leaving it to the garbage collector
    with Image.open(path) as img:
        rgb = img.resize([256, 256]).convert('RGB')
    pixels = np.array(rgb).astype(np.float32) / 255.0
    if center:
        pixels = pixels[64:192, 64:192, :]
    return pixels


if __name__ == '__main__':
    # For every ground-truth image, score all generated images whose file name
    # starts with the same stem, keep the best one, and report mean/variance
    # of the best scores over the dataset.
    gt_paths, gt_size = make_dataset(args.gt_path)
    # The returned list is not used: candidates are globbed per ground-truth
    # image below. The call is kept so that whatever checks make_dataset
    # performs on --g_path still run up front.
    make_dataset(args.g_path)

    l1losses = []
    ssims = []
    psnrs = []
    lpipses = []

    size = args.num_test if args.num_test > 0 else gt_size

    # slicing clamps to the available paths, so an oversized --num_test can
    # no longer raise IndexError
    for gt_path in gt_paths[:size]:
        gt_numpy = _load_rgb_array(gt_path, args.center)

        # stem = file name up to its first dot; escaped so that characters
        # such as '[' in a file name are not interpreted as glob syntax
        stem = os.path.basename(gt_path).split(".")[0]
        pattern = os.path.join(args.g_path, glob.escape(stem) + "*")
        candidates = sorted(glob.glob(pattern))

        best_path = None
        best_scores = None
        best_rank = None

        for candidate in candidates:
            try:
                g_numpy = _load_rgb_array(candidate, args.center)
                scores = calculate_score(gt_numpy, g_numpy)
            except Exception as err:
                # best effort: report the candidate that could not be read or
                # scored and carry on (Ctrl-C is no longer swallowed)
                print(candidate, err)
                continue

            l1loss, ssim_score, psnr_score, lpips_score = scores
            # single ranking value, lower is better: errors count positively,
            # similarities negatively
            rank = l1loss - ssim_score - psnr_score + lpips_score
            if best_rank is None or rank < best_rank:
                best_rank, best_scores, best_path = rank, scores, candidate

        if best_path is None:
            # no candidate matched, or none could be scored
            continue

        l1loss_sample, ssim_sample, psnr_sample, lpips_sample = best_scores
        print(best_path)
        print(l1loss_sample, ssim_sample, psnr_sample, lpips_sample)
        l1losses.append(l1loss_sample)
        ssims.append(ssim_sample)
        psnrs.append(psnr_sample)
        lpipses.append(lpips_sample)

        if args.save_path is not None:
            util.mkdir(args.save_path)
            shutil.copy(best_path, args.save_path)

    print('{:>10},{:>10},{:>10},{:>10}'.format('l1loss', 'SSIM', 'PSNR', 'LPIPS'))
    if l1losses:
        # first row: mean of the per-image best scores; second row: variance
        print('{:10.4f},{:10.4f},{:10.4f},{:10.4f}'.format(np.mean(l1losses), np.mean(ssims), np.mean(psnrs), np.mean(lpipses)))
        print('{:10.4f},{:10.4f},{:10.4f},{:10.4f}'.format(np.var(l1losses), np.var(ssims), np.var(psnrs), np.var(lpipses)))
    else:
        # avoid a row of NaNs (and NumPy warnings) when nothing was scored
        print('no image pairs were evaluated')