File size: 2,153 Bytes
34b61ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import torch, os, glob, pyiqa
from argparse import ArgumentParser
import numpy as np
from PIL import Image
from tqdm import tqdm
from torchvision import transforms

# Command-line interface: two string options naming the directories to compare,
# each with a default path.
parser = ArgumentParser()
for _flag, _default_dir in (
    ("--HR_dir", "testset/RealSR/HR"),
    ("--SR_dir", "result/RealSR"),
):
    parser.add_argument(_flag, type=str, default=_default_dir)
args = parser.parse_args()

# Run on the GPU when one is available; otherwise fall back to the CPU so the
# script does not fail outright on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Metrics that are called below with both the SR and the HR image.
# PSNR and SSIM are created with test_y_channel=True and color_space="ycbcr".
psnr = pyiqa.create_metric("psnr", test_y_channel=True, color_space="ycbcr", device=device)
ssim = pyiqa.create_metric("ssim", test_y_channel=True, color_space="ycbcr", device=device)
lpips = pyiqa.create_metric("lpips", device=device)
dists = pyiqa.create_metric("dists", device=device)

# FID is called once at the end with the two directory paths rather than with
# individual image tensors.
fid = pyiqa.create_metric("fid", device=device)

# Metrics that are called below with the SR image only.
niqe = pyiqa.create_metric("niqe", device=device)
maniqa = pyiqa.create_metric("maniqa-pipal", device=device)
clipiqa = pyiqa.create_metric("clipiqa", device=device)
musiq = pyiqa.create_metric("musiq", device=device)

# List every entry in the two directories. Both listings are sorted so that
# the i-th SR path is paired with the i-th HR path.
test_SR_paths = sorted(glob.glob(os.path.join(args.SR_dir, "*")))
test_HR_paths = sorted(glob.glob(os.path.join(args.HR_dir, "*")))

# zip() stops at the shorter listing, which would silently score only part of
# the data (and possibly misaligned pairs). Fail loudly instead.
if not test_SR_paths:
    raise ValueError(f"No files found in SR directory {args.SR_dir!r}")
if len(test_SR_paths) != len(test_HR_paths):
    raise ValueError(
        f"Found {len(test_SR_paths)} files in {args.SR_dir!r} but "
        f"{len(test_HR_paths)} files in {args.HR_dir!r}; "
        "the SR and HR directories must contain the same number of images"
    )

# Per-image scores, one list per metric; the key order is also the print order.
metrics = {"psnr": [], "ssim": [], "lpips": [], "dists": [], "niqe": [], "maniqa": [], "musiq": [], "clipiqa": []}

# Build the PIL -> tensor transform once instead of twice per iteration.
_to_tensor = transforms.ToTensor()


def _load_image(path):
    """Open *path*, convert it to RGB and return it as a tensor on `device`.

    The tensor gets an extra leading dimension via ``unsqueeze(0)``. The file
    handle is closed as soon as the pixel data has been converted.
    """
    with Image.open(path) as img:
        return _to_tensor(img.convert("RGB")).to(device).unsqueeze(0)


# total= lets tqdm display a real progress bar; zip() alone has no length.
for SR_path, HR_path in tqdm(zip(test_SR_paths, test_HR_paths), total=len(test_SR_paths)):
    SR = _load_image(SR_path)
    HR = _load_image(HR_path)
    # Metrics computed from the (SR, HR) pair.
    metrics["psnr"].append(psnr(SR, HR).item())
    metrics["ssim"].append(ssim(SR, HR).item())
    metrics["lpips"].append(lpips(SR, HR).item())
    metrics["dists"].append(dists(SR, HR).item())
    # Metrics computed from the SR image alone.
    metrics["niqe"].append(niqe(SR).item())
    metrics["maniqa"].append(maniqa(SR).item())
    metrics["clipiqa"].append(clipiqa(SR).item())
    metrics["musiq"].append(musiq(SR).item())

# Replace each per-image score list with its mean. update() modifies the
# existing dict in place, so the key order is unchanged.
metrics.update({_name: np.mean(_scores) for _name, _scores in metrics.items()})

# FID is computed once from the two directories and stored as the last entry.
metrics["fid"] = fid(args.SR_dir, args.HR_dir)

# Significant digits used when printing; metrics not listed here get four.
_sig_digits = {"niqe": 3, "fid": 5}
for _name, _value in metrics.items():
    print(_name, format(_value, f".{_sig_digits.get(_name, 4)}g"))