File size: 2,153 Bytes
34b61ae |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
import torch, os, glob, pyiqa
from argparse import ArgumentParser
import numpy as np
from PIL import Image
from tqdm import tqdm
from torchvision import transforms
# Command-line options: the directory of ground-truth (HR) images and the
# directory of restored (SR) images that are scored against them.
parser = ArgumentParser(
    description="Score super-resolved images against high-resolution references."
)
parser.add_argument(
    "--HR_dir",
    type=str,
    default="testset/RealSR/HR",
    help="directory of ground-truth high-resolution images",
)
parser.add_argument(
    "--SR_dir",
    type=str,
    default="result/RealSR",
    help="directory of super-resolved images to evaluate; paired with HR_dir "
    "by sorted filename order, so both must hold the same number of files",
)
args = parser.parse_args()
# Run the metrics on the GPU when one is present, and fall back to the CPU
# instead of failing at startup on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Full-reference metrics: called below with both the SR and the HR image.
# PSNR/SSIM are configured to score only the Y channel of YCbCr.
psnr = pyiqa.create_metric("psnr", test_y_channel=True, color_space="ycbcr", device=device)
ssim = pyiqa.create_metric("ssim", test_y_channel=True, color_space="ycbcr", device=device)
lpips = pyiqa.create_metric("lpips", device=device)
dists = pyiqa.create_metric("dists", device=device)
# Set-level metric: called once with the SR and HR directory paths.
fid = pyiqa.create_metric("fid", device=device)
# No-reference metrics: called below with the SR image alone.
niqe = pyiqa.create_metric("niqe", device=device)
maniqa = pyiqa.create_metric("maniqa-pipal", device=device)
clipiqa = pyiqa.create_metric("clipiqa", device=device)
musiq = pyiqa.create_metric("musiq", device=device)
# Pair each restored (SR) image with its ground-truth (HR) image by position in
# the sorted directory listings. Filenames are never compared, so the two
# directories must contain the same number of files in corresponding order.
test_SR_paths = sorted(glob.glob(os.path.join(args.SR_dir, "*")))
test_HR_paths = sorted(glob.glob(os.path.join(args.HR_dir, "*")))
if not test_SR_paths:
    # Averaging zero scores would only yield NaN.
    raise ValueError(f"no files found in SR_dir {args.SR_dir!r}")
if len(test_SR_paths) != len(test_HR_paths):
    # zip() would silently stop at the shorter list, leaving surplus images
    # unscored and possibly pairing the wrong files together.
    raise ValueError(
        f"SR_dir has {len(test_SR_paths)} files but HR_dir has "
        f"{len(test_HR_paths)}; cannot pair them by sorted order"
    )

# Per-image scores for each metric; averaged after the loop.
metrics = {
    "psnr": [],
    "ssim": [],
    "lpips": [],
    "dists": [],
    "niqe": [],
    "maniqa": [],
    "musiq": [],
    "clipiqa": [],
}

# Built once instead of once per image.
to_tensor = transforms.ToTensor()


def _load_image(path):
    """Open *path* as RGB and return a batched (1, 3, H, W) float tensor on ``device``.

    The ``with`` block closes the underlying file once the pixels are converted.
    """
    with Image.open(path) as img:
        return to_tensor(img.convert("RGB")).to(device).unsqueeze(0)


for SR_path, HR_path in tqdm(zip(test_SR_paths, test_HR_paths), total=len(test_SR_paths)):
    SR = _load_image(SR_path)
    HR = _load_image(HR_path)
    # Full-reference metrics: compare the restored image with its ground truth.
    metrics["psnr"].append(psnr(SR, HR).item())
    metrics["ssim"].append(ssim(SR, HR).item())
    metrics["lpips"].append(lpips(SR, HR).item())
    metrics["dists"].append(dists(SR, HR).item())
    # No-reference metrics: score the restored image on its own.
    metrics["niqe"].append(niqe(SR).item())
    metrics["maniqa"].append(maniqa(SR).item())
    metrics["clipiqa"].append(clipiqa(SR).item())
    metrics["musiq"].append(musiq(SR).item())
# Collapse each metric's per-image score list into its mean (in place, keeping
# key order), then append the set-level FID computed from both directories.
metrics.update({name: np.mean(scores) for name, scores in metrics.items()})
metrics["fid"] = fid(args.SR_dir, args.HR_dir)
# Report every metric; NIQE and FID get their own significant-digit counts,
# and all other metrics use four.
_PRINT_PRECISION = {"niqe": ".3g", "fid": ".5g"}
for name, value in metrics.items():
    print(name, format(value, _PRINT_PRECISION.get(name, ".4g")))