File size: 3,957 Bytes
4730cdc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Powered by Zongsheng Yue 2022-08-13 21:37:58
'''
Calculate LPIPS (VGG and AlexNet backbones), CLIP-IQA, and MUSIQ scores for each mask type.
'''
import os, sys, math
import lpips
import pyiqa
import pickle
import argparse
import numpy as np
from scipy import linalg
from pathlib import Path
from loguru import logger as base_logger
import torch
import torch.nn as nn
sys.path.append(str(Path(__file__).resolve().parents[1]))
from utils import util_image
def load_im_tensor(im_path):
    """
    Read an image from disk as an RGB float32 tensor on the GPU, scaled to [-1, 1].

    Args:
        im_path: path of the image file, passed straight to ``util_image.imread``.

    Returns:
        A CUDA tensor of shape (1, C, H, W). The HWC array from ``imread`` is
        permuted to CHW and given a leading batch axis. It is then mapped
        through ``(x - 0.5) / 0.5``.

    NOTE(review): the mapping lands in [-1, 1] only if ``imread`` returns
    values in [0, 1] for ``dtype='float32'`` -- confirm against util_image.
    """
    arr = util_image.imread(im_path, chn='rgb', dtype='float32')
    # HWC numpy array -> CHW tensor -> add batch axis -> move to GPU
    tensor = torch.from_numpy(arr).permute(2, 0, 1)[None].cuda()
    return tensor.sub(0.5).div(0.5)
def main():
    """
    Evaluate SR images against GT images, grouped by mask type.

    The mapping ``<gt_dir>/../infos/mask_split.pkl`` lists, for each mask type,
    the image names to evaluate. For every listed image this computes:

    * LPIPS (VGG and AlexNet backbones) between the SR image in ``--sr_dir``
      and the GT image with the same file name in ``--gt_dir``;
    * the no-reference CLIP-IQA and MUSIQ scores of the SR image.

    Per-mask-type averages and their unweighted mean over mask types are
    logged to stderr and to ``<sr_dir>/../metrics.log``, which is overwritten.
    Mask types with no images are skipped and do not count toward the mean.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--gt_dir", type=str, default="", help="Path of the HQ (ground-truth) images")
    parser.add_argument("--sr_dir", type=str, default="", help="Path of the SR images")
    args = parser.parse_args()

    # setting logger: the log file ('w' mode, overwritten each run) plus stderr
    log_path = str(Path(args.sr_dir).parent / 'metrics.log')
    logger = base_logger
    logger.remove()
    logger.add(log_path, format="{time:YYYY-MM-DD(HH:mm:ss)}: {message}", mode='w', level='INFO')
    logger.add(sys.stderr, format="{message}", level='INFO')
    for key in vars(args):
        value = getattr(args, key)
        logger.info(f'{key}: {value}')

    lpips_metric_vgg = lpips.LPIPS(net='vgg').cuda()
    lpips_metric_alex = lpips.LPIPS(net='alex').cuda()
    clipiqa_metric = pyiqa.create_metric('clipiqa')
    musiq_metric = pyiqa.create_metric('musiq')

    # mask_split maps mask type -> iterable of image names under sr_dir.
    # NOTE(review): pickle.load can execute arbitrary code; only point
    # --gt_dir at trusted data.
    info_path = Path(args.gt_dir).parent / 'infos' / 'mask_split.pkl'
    with open(str(info_path), mode='rb') as ff:
        mask_split = pickle.load(ff)

    mean_lpips_vgg = 0
    mean_lpips_alex = 0
    mean_clipiqa = 0
    mean_musiq = 0
    num_mask_types = 0
    for mask_type in mask_split.keys():
        im_path_list = [(Path(args.sr_dir) / im_name) for im_name in mask_split[mask_type]]
        logger.info(f"Mask types: {mask_type}, images: {len(im_path_list)}")
        if not im_path_list:
            # Averaging over zero images would divide by zero; skip this split.
            logger.warning(f"  No images for mask type {mask_type}, skipped.")
            continue
        num_mask_types += 1

        current_lpips_vgg = 0
        current_lpips_alex = 0
        current_clipiqa = 0
        current_musiq = 0
        for im_path_sr in im_path_list:
            im_sr = load_im_tensor(im_path_sr)
            im_path_gt = Path(args.gt_dir) / im_path_sr.name
            im_gt = load_im_tensor(im_path_gt)
            # load_im_tensor yields [-1, 1], which is what LPIPS expects.
            # pyiqa metrics expect RGB tensors in [0, 1], so undo the scaling.
            im_sr_01 = im_sr * 0.5 + 0.5
            with torch.no_grad():
                # calculate lpips
                current_lpips_vgg += lpips_metric_vgg(im_gt, im_sr).sum().item()
                current_lpips_alex += lpips_metric_alex(im_gt, im_sr).sum().item()
                # calculate clipiqa
                current_clipiqa += clipiqa_metric(im_sr_01).sum().item()
                # calculate musiq
                current_musiq += musiq_metric(im_sr_01).sum().item()

        num_images = len(im_path_list)
        # per-mask-type averages
        current_lpips_vgg /= num_images
        current_lpips_alex /= num_images
        current_clipiqa /= num_images
        current_musiq /= num_images
        mean_lpips_vgg += current_lpips_vgg
        mean_lpips_alex += current_lpips_alex
        mean_clipiqa += current_clipiqa
        mean_musiq += current_musiq
        logger.info(f"    LPIPS-VGG: {current_lpips_vgg:6.4f}")
        logger.info(f"    LPIPS-Alex: {current_lpips_alex:6.4f}")
        logger.info(f"    CLIPIQA: {current_clipiqa:6.4f}")
        logger.info(f"    MUSIQ: {current_musiq:5.2f}")

    if num_mask_types == 0:
        logger.warning("No images found for any mask type; nothing to average.")
        return

    # Unweighted mean over mask types: each type counts once regardless of size.
    mean_lpips_vgg /= num_mask_types
    mean_lpips_alex /= num_mask_types
    mean_clipiqa /= num_mask_types
    mean_musiq /= num_mask_types

    logger.info(f"MEAN LPIPS-VGG: {mean_lpips_vgg:6.4f}")
    logger.info(f"MEAN LPIPS-Alex: {mean_lpips_alex:6.4f}")
    logger.info(f"MEAN CLIPIQA: {mean_clipiqa:6.4f}")
    logger.info(f"MEAN MUSIQ: {mean_musiq:5.2f}")
if __name__ == '__main__':
    # Run the evaluation only when executed as a script, not on import.
    main()
|