# LHM/tools/metrics/compute_psnr.py
# QZFantasies's picture
# add wheels
# c614b0f
# -*- coding: utf-8 -*-
# @Organization : Alibaba XR-Lab
# @Author : Lingteng Qiu
# @Email : [email protected]
# @Time : 2025-03-03 10:28:35
# @Function : Easy to use PSNR metric
import os
import sys
sys.path.append("./")
import math
import pdb
import cv2
import numpy as np
import skimage
import torch
from PIL import Image
from tqdm import tqdm
def write_json(path, x):
    """Dump an object to disk as indented JSON.

    The destination is opened in text write mode (created, or truncated if
    it already exists) and the payload is serialized with a two-space indent.

    Args:
        path (str): destination file path.
        x (dict): JSON-serializable object to store.
    """
    import json

    with open(path, "w") as handle:
        json.dump(x, handle, indent=2)
def img_center_padding(img_np, pad_ratio=0.2, background=1):
    """Place an image at the center of a larger, uniformly filled canvas.

    Each spatial side of the canvas is ``round((1 + pad_ratio) * side)`` of
    the input. The canvas keeps the dtype of the input and follows its
    trailing (channel) dimensions, so 2-D grayscale and 4-channel inputs are
    supported in addition to 3-channel ones.

    Args:
        img_np (np.ndarray): image of shape ``(rows, cols)`` or
            ``(rows, cols, channels)``.
        pad_ratio (float): relative growth applied to both spatial sides.
        background (int): ``1`` fills the canvas with ones; any other value
            fills it with zeros.

    Returns:
        tuple: ``(padded_image, col_offset, row_offset)`` where the offsets
        give the top-left position of the original image inside the canvas.
    """
    rows, cols = img_np.shape[:2]
    pad_rows = round((1 + pad_ratio) * rows)
    pad_cols = round((1 + pad_ratio) * cols)

    channel_dims = tuple(img_np.shape[2:])
    if channel_dims == (1,):
        # A (rows, cols, 1) input has always been broadcast onto a
        # three-channel canvas; keep that result for existing callers.
        channel_dims = (3,)

    make_canvas = np.ones if background == 1 else np.zeros
    img_pad_np = make_canvas((pad_rows, pad_cols) + channel_dims, dtype=img_np.dtype)

    row_offset = (pad_rows - rows) // 2
    col_offset = (pad_cols - cols) // 2
    img_pad_np[
        row_offset : row_offset + rows, col_offset : col_offset + cols
    ] = img_np

    # Column offset first, then row offset: this is the order callers receive.
    return img_pad_np, col_offset, row_offset
def compute_psnr(src, tar):
    """Return the peak signal-to-noise ratio between two images.

    Args:
        src (np.ndarray): image under test.
        tar (np.ndarray): reference image; passed to scikit-image as the
            "true" image.

    Returns:
        float: PSNR computed with ``data_range=1``, i.e. both inputs are
        expected to be scaled to a unit value range.
    """
    # Import the function from its submodule explicitly: a bare
    # ``import skimage`` does not make ``skimage.metrics`` available on
    # scikit-image releases that do not lazy-load submodules.
    from skimage.metrics import peak_signal_noise_ratio

    return peak_signal_noise_ratio(tar, src, data_range=1)
def get_parse(argv=None):
    """Parse the command-line options of the PSNR script.

    Args:
        argv (list[str] | None): argument strings to parse. ``None`` (the
            default) lets argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: namespace with the attributes ``folder1``,
        ``folder2``, ``mask``, ``pre``, ``debug`` and ``pad``.
    """
    import argparse

    parser = argparse.ArgumentParser(description="")
    parser.add_argument("-f1", "--folder1", required=True, help="input path")
    parser.add_argument("-f2", "--folder2", required=True, help="output path")
    # The help text used to read "output path", copied from --folder2.
    parser.add_argument("-m", "--mask", default=None, help="mask path (optional)")
    parser.add_argument(
        "--pre", default="", help="suffix appended to the metrics output folder name"
    )
    parser.add_argument(
        "--debug", action="store_true", help="debug mode: stop early after one item"
    )
    parser.add_argument("--pad", action="store_true", help="if the gt pad?")
    return parser.parse_args(argv)
def get_image_paths_current_dir(folder_path):
    """List the image files that sit directly inside a folder.

    Entries are selected by file extension (compared case-insensitively)
    and the folder is not searched recursively.

    Args:
        folder_path (str): directory to scan.

    Returns:
        list[str]: ``folder_path`` joined with each matching entry name,
        sorted in ascending string order.
    """
    known_suffixes = {
        ".jpg",
        ".jpeg",
        ".png",
        ".gif",
        ".bmp",
        ".tiff",
        ".webp",
        ".jfif",
    }

    matches = []
    for entry in os.listdir(folder_path):
        _, suffix = os.path.splitext(entry)
        if suffix.lower() in known_suffixes:
            matches.append(os.path.join(folder_path, entry))
    matches.sort()
    return matches
def _read_image_unit_range(path):
    """Load an image with OpenCV (unchanged channels) scaled by 1/255 as float32."""
    raw = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if raw is None:
        # cv2.imread reports an unreadable file by returning None instead of raising.
        raise ValueError(f"failed to read image: {path}")
    return (raw / 255.0).astype(np.float32)


def _composite_on_white(img, mask_label):
    """Blend ``img`` over a white background using a single-channel mask."""
    weight = mask_label[..., None] if img.ndim == 3 else mask_label
    return img * weight + 1 - weight


def psnr_compute(
    input_data,
    results_data,
    mask_data=None,
    pad=False,
    target_width=512,
):
    """Compute the mean PSNR between two folders of images.

    Images of both folders are listed in sorted order and compared pairwise.
    Every pair is resized to ``target_width`` columns (the row count follows
    the aspect ratio of the result image) before the metric is evaluated.
    When masks are given, both images are composited on white with the mask
    and only pixels whose mask value exceeds 0.5 contribute to the PSNR.

    Args:
        input_data (str): folder with the ground-truth images.
        results_data (str): folder with the result images. A trailing image
            whose file name contains "visualization" is ignored.
        mask_data (str | None): optional folder with mask images, matched to
            the image pairs by sorted position.
        pad (bool): center-pad every ground-truth image before resizing.
        target_width (int): width both images are resized to.

    Returns:
        float | int: mean PSNR over all pairs, or ``-1`` when the folders
        cannot be paired (no images, or different image counts).

    Raises:
        ValueError: if an image cannot be read, or if fewer masks than image
            pairs are available.
    """
    gt_imgs = get_image_paths_current_dir(input_data)
    result_imgs = get_image_paths_current_dir(results_data)
    mask_imgs = None if mask_data is None else get_image_paths_current_dir(mask_data)

    # Test the file name only, so a folder called "visualization" somewhere
    # in the path does not drop a regular frame; also safe on an empty list.
    if result_imgs and "visualization" in os.path.basename(result_imgs[-1]):
        result_imgs = result_imgs[:-1]

    if not gt_imgs or len(gt_imgs) != len(result_imgs):
        return -1

    if mask_imgs is not None and len(mask_imgs) < len(gt_imgs):
        raise ValueError(
            f"found {len(mask_imgs)} masks for {len(gt_imgs)} images in {mask_data}"
        )

    psnr_values = []
    pairs = zip(gt_imgs, result_imgs)
    for idx, (gt, result) in enumerate(tqdm(pairs, total=len(gt_imgs))):
        result_img = _read_image_unit_range(result)
        gt_img = _read_image_unit_range(gt)

        if pad:
            gt_img, _, _ = img_center_padding(gt_img)

        h, w = result_img.shape[:2]
        scale_h = int(h * target_width / w)
        # cv2.resize takes the target size as (width, height).
        size = (target_width, scale_h)
        gt_img = cv2.resize(gt_img, size, interpolation=cv2.INTER_AREA)
        result_img = cv2.resize(result_img, size, interpolation=cv2.INTER_AREA)

        if mask_imgs is None:
            psnr_values.append(compute_psnr(result_img, gt_img))
            continue

        mask_img = _read_image_unit_range(mask_imgs[idx])
        if mask_img.ndim == 3:
            # Multi-channel mask: the last channel carries the mask values.
            mask_img = mask_img[..., -1]
        mask_img = np.stack([mask_img] * 3, axis=-1)
        # NOTE(review): the mask is center-padded regardless of ``pad``;
        # presumably masks always match the unpadded ground truth - confirm.
        mask_img, _, _ = img_center_padding(mask_img, background=0)
        # Resize to the same size as the images; the original (w, h) of the
        # result only matches them when the result is already target_width wide.
        mask_img = cv2.resize(mask_img, size, interpolation=cv2.INTER_AREA)
        mask_label = mask_img[..., 0]

        gt_img = _composite_on_white(gt_img, mask_label)
        result_img = _composite_on_white(result_img, mask_label)

        foreground = mask_label > 0.5
        psnr_values.append(compute_psnr(result_img[foreground], gt_img[foreground]))

    return np.mean(psnr_values)
if __name__ == "__main__":
    # Evaluate PSNR for every item listed in <folder1>/front_view.txt and
    # store the per-item values plus their mean in a PSNR.json file.
    opt = get_parse()

    input_folder = opt.folder1
    # rstrip also copes with an empty string and with repeated slashes.
    target_folder = opt.folder2.rstrip("/")
    mask_folder = opt.mask.rstrip("/") if opt.mask is not None else None

    valid_txt = os.path.join(input_folder, "front_view.txt")

    # The last two components of the results path name the output sub-folder.
    target_key = target_folder.split("/")[-2:]
    save_folder = os.path.join(f"./exps/metrics{opt.pre}", "psnr_results", *target_key)
    os.makedirs(save_folder, exist_ok=True)

    with open(valid_txt) as f:
        items = [line.split(" ")[0] for line in f.read().splitlines()]
    # A blank line would otherwise yield an empty item name, which resolves
    # to the root input folder itself.
    items = [item for item in items if item]

    results_dict = dict()
    psnr_list = []

    for item in items:
        input_item_folder = os.path.join(input_folder, item)
        if mask_folder is not None:
            mask_item_folder = os.path.join(mask_folder, item)
        else:
            mask_item_folder = None
        target_item_folder = os.path.join(target_folder, item, "rgb")

        if os.path.exists(input_item_folder) and os.path.exists(target_item_folder):
            psnr = psnr_compute(
                input_item_folder, target_item_folder, mask_item_folder, opt.pad
            )

            # -1 marks an item whose folders could not be paired.
            if psnr == -1:
                continue

            # Store plain Python floats so printing and JSON output do not
            # depend on the NumPy scalar type.
            psnr = float(psnr)
            psnr_list.append(psnr)
            results_dict[item] = psnr

            if opt.debug:
                break

    print(results_dict)

    # With no evaluated item there is no mean; write null rather than the
    # NaN (invalid JSON) that np.mean([]) would produce.
    results_dict["all_mean"] = float(np.mean(psnr_list)) if psnr_list else None

    print(save_folder)
    print(results_dict)
    write_json(os.path.join(save_folder, "PSNR.json"), results_dict)