Spaces:
Running
on
Zero
Running
on
Zero
# -*- coding: utf-8 -*-
# @Organization : Alibaba XR-Lab
# @Author : Lingteng Qiu
# @Email : [email protected]
# @Time : 2025-03-03 10:28:35
# @Function : Easy-to-use PSNR metric
import os | |
import sys | |
sys.path.append("./") | |
import math | |
import pdb | |
import cv2 | |
import numpy as np | |
import skimage | |
import torch | |
from PIL import Image | |
from tqdm import tqdm | |
from tqlt import utils as tu | |
def write_json(path, x):
    """Serialize ``x`` as JSON text into the file at ``path``.

    Args:
        path (str): destination file; opened in text write mode, so any
            existing content is replaced.
        x (dict): object to serialize, emitted with a two-space indent.
    """
    import json

    encoder = json.JSONEncoder(indent=2)
    with open(path, "w") as handle:
        # Stream the encoder's chunks straight to the file, in order.
        handle.writelines(encoder.iterencode(x))
def img_center_padding(img_np, pad_ratio=0.2, background=1):
    """Center ``img_np`` on a larger canvas filled with ones or zeros.

    The two leading axes are each scaled by ``1 + pad_ratio`` (rounded) and the
    input is pasted in the middle of the new canvas.

    Args:
        img_np (np.ndarray): array whose first two axes are the ones to pad.
            Inputs with one or three channels produce a three-channel canvas
            (a single channel is broadcast across all three). Any other
            layout, e.g. 2-D or four-channel, keeps its own trailing shape.
        pad_ratio (float): relative growth applied to both leading axes.
        background (int): ``1`` fills the canvas with ones; any other value
            fills it with zeros.

    Returns:
        tuple: ``(padded, offset_w, offset_h)`` where ``offset_w`` is the paste
        offset along axis 1 and ``offset_h`` the paste offset along axis 0.
    """
    rows, cols = img_np.shape[:2]
    padded_rows = round((1 + pad_ratio) * rows)
    padded_cols = round((1 + pad_ratio) * cols)

    # Follow the input's channel layout instead of assuming exactly three
    # channels, so RGBA and 2-D arrays no longer fail on assignment. One- and
    # three-channel inputs keep producing a three-channel canvas.
    trailing = tuple(img_np.shape[2:])
    if trailing in ((1,), (3,)):
        trailing = (3,)
    canvas_shape = (padded_rows, padded_cols) + trailing

    make_canvas = np.ones if background == 1 else np.zeros
    img_pad_np = make_canvas(canvas_shape, dtype=img_np.dtype)

    offset_h = (padded_rows - rows) // 2
    offset_w = (padded_cols - cols) // 2
    img_pad_np[offset_h : offset_h + rows, offset_w : offset_w + cols] = img_np
    return img_pad_np, offset_w, offset_h
def compute_psnr(src, tar):
    """Return the peak signal-to-noise ratio between ``src`` and ``tar``.

    The computation is delegated to scikit-image with ``data_range=1``.

    Args:
        src (np.ndarray): image under evaluation.
        tar (np.ndarray): reference image; passed to scikit-image as the
            ground-truth argument.

    Returns:
        float: the PSNR value reported by scikit-image.
    """
    # Import the submodule explicitly: a bare ``import skimage`` does not make
    # ``skimage.metrics`` available on releases without lazy submodule loading.
    from skimage.metrics import peak_signal_noise_ratio

    return peak_signal_noise_ratio(tar, src, data_range=1)
def get_parse(argv=None):
    """Parse the command-line options of the PSNR evaluation script.

    Args:
        argv (list[str] | None): argument tokens to parse. ``None`` (the
            default) lets argparse read them from ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: parsed options with the attributes ``folder1``,
        ``folder2``, ``mask``, ``pre``, ``debug`` and ``pad``.
    """
    import argparse

    parser = argparse.ArgumentParser(description="")
    parser.add_argument("-f1", "--folder1", required=True, help="input path")
    parser.add_argument("-f2", "--folder2", required=True, help="output path")
    # The help text used to read "output path", copied from --folder2.
    parser.add_argument("-m", "--mask", default=None, help="mask path")
    parser.add_argument("--pre", default="anigs")
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--pad", action="store_true")
    return parser.parse_args(argv)
def get_image_paths_current_dir(folder_path):
    """Collect the image entries located directly inside ``folder_path``.

    Only the folder itself is listed (no recursion). An entry is kept when its
    extension, compared case-insensitively, is one of the known image
    extensions.

    Args:
        folder_path (str): directory to list.

    Returns:
        list[str]: ``folder_path`` joined with each matching entry name, in
        ascending sorted order.
    """
    image_extensions = {
        ".jpg",
        ".jpeg",
        ".png",
        ".gif",
        ".bmp",
        ".tiff",
        ".webp",
        ".jfif",
    }

    matches = []
    for entry in os.listdir(folder_path):
        _, extension = os.path.splitext(entry)
        if extension.lower() in image_extensions:
            matches.append(os.path.join(folder_path, entry))
    matches.sort()
    return matches
def _load_unit_float_image(path):
    """Read ``path`` with OpenCV (channels untouched) and divide by 255 as float32."""
    raw = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if raw is None:
        # cv2.imread reports an unreadable file by returning None, not by raising.
        raise IOError(f"failed to read image: {path}")
    return (raw / 255.0).astype(np.float32)


def psnr_compute(
    input_data,
    results_data,
    mask_data=None,
    pad=False,
):
    """Average the PSNR between matching frames of two image folders.

    Frames are paired by sorted order and only every fourth pair is evaluated.
    Each reference frame is resized to the size of its result frame before the
    comparison. When masks are given, both frames are blended so that pixels
    where the mask is 0 become 1, and the PSNR is taken only over pixels whose
    mask value exceeds 0.5.

    Args:
        input_data (str): folder holding the reference frames.
        results_data (str): folder holding the frames to evaluate. A trailing
            entry whose file name contains "visualization" is ignored.
        mask_data (str | None): optional folder holding one mask per frame;
            the last channel of each mask image is used.
        pad (bool): when True, reference frames are center-padded with
            ``img_center_padding`` before resizing.

    Returns:
        float | int: mean PSNR over the evaluated frames, or ``-1`` when the two
        folders do not hold the same number of frames or hold none at all.

    Raises:
        IOError: if an image file cannot be decoded.
    """
    frame_stride = 4

    gt_imgs = get_image_paths_current_dir(input_data)
    result_imgs = get_image_paths_current_dir(results_data)
    mask_imgs = None
    if mask_data is not None:
        mask_imgs = get_image_paths_current_dir(mask_data)

    # Match on the file name only, so a parent directory called
    # "visualization" does not cause a real frame to be dropped. The emptiness
    # guard avoids an IndexError on a results folder without images.
    if result_imgs and "visualization" in os.path.basename(result_imgs[-1]):
        result_imgs = result_imgs[:-1]

    if not gt_imgs or len(gt_imgs) != len(result_imgs):
        return -1

    # A mask list covering every frame must be subsampled together with the
    # frames, otherwise mask i would be paired with frame i * frame_stride.
    # A list of any other length is indexed as-is, as before.
    if mask_imgs is not None and len(mask_imgs) == len(gt_imgs):
        mask_imgs = mask_imgs[::frame_stride]
    gt_imgs = gt_imgs[::frame_stride]
    result_imgs = result_imgs[::frame_stride]

    psnr_values = []
    frame_pairs = zip(gt_imgs, result_imgs)
    for frame_idx, (gt_path, result_path) in tqdm(
        enumerate(frame_pairs), total=len(gt_imgs)
    ):
        result_img = _load_unit_float_image(result_path)
        gt_img = _load_unit_float_image(gt_path)

        mask_img = None
        if mask_imgs is not None:
            mask_img = _load_unit_float_image(mask_imgs[frame_idx])
            if mask_img.ndim == 3:
                # Multi-channel mask: use the last channel. A 2-D mask is used
                # as-is; slicing it would select its last column instead.
                mask_img = mask_img[..., -1]
            mask_img = np.stack([mask_img] * 3, axis=-1)
            # NOTE(review): the mask is padded even when ``pad`` is False,
            # while the reference frame is padded only when ``pad`` is True.
            # Confirm this asymmetry is intended.
            mask_img, _, _ = img_center_padding(mask_img, background=0)

        if pad:
            gt_img, _, _ = img_center_padding(gt_img)

        # Only the spatial size is needed; this also accepts 2-D results.
        h, w = result_img.shape[:2]
        gt_img = cv2.resize(gt_img, (w, h), interpolation=cv2.INTER_AREA)

        if mask_img is not None:
            mask_img = cv2.resize(mask_img, (w, h), interpolation=cv2.INTER_AREA)
            # Blend toward 1 wherever the mask is 0.
            gt_img = gt_img * mask_img + 1 - mask_img
            result_img = result_img * mask_img + 1 - mask_img
            foreground = mask_img[..., 0] > 0.5
            psnr_values.append(
                compute_psnr(result_img[foreground], gt_img[foreground])
            )
        else:
            psnr_values.append(compute_psnr(result_img, gt_img))

    return np.mean(psnr_values)
if __name__ == "__main__":
    # Script entry point: evaluate every item folder found under --folder1
    # against the folder of the same name under --folder2, then write the
    # per-item PSNR values and their mean to PSNR.json.
    opt = get_parse()
    input_folder = opt.folder1
    target_folder = opt.folder2
    # NOTE(review): --mask is read here but never forwarded; every item is
    # evaluated with mask_item_folder=None. Wire it up once the mask folder
    # layout is confirmed.
    mask_folder = opt.mask

    save_folder = os.path.join(
        f"./exps/metrics{opt.pre}", "psnr_results", "anigs_video"
    )
    os.makedirs(save_folder, exist_ok=True)

    # presumably returns the sub-folders of input_folder; verify against tqlt
    input_folders = tu.next_folders(input_folder)

    results_dict = dict()
    psnr_list = []
    for item_folder in input_folders:
        item_basename = tu.basename(item_folder)
        mask_item_folder = None
        input_item_folder = os.path.join(item_folder, "rgb")
        target_item_folder = os.path.join(target_folder, item_basename)

        # Skip items that lack either the reference frames or the results.
        if not (
            os.path.exists(input_item_folder) and os.path.exists(target_item_folder)
        ):
            continue

        psnr = psnr_compute(
            input_item_folder, target_item_folder, mask_item_folder, opt.pad
        )
        if psnr == -1:
            # psnr_compute signals an unusable item (frame count mismatch) with -1.
            continue

        # Store plain Python floats so the values serialize cleanly to JSON.
        psnr = float(psnr)
        psnr_list.append(psnr)
        results_dict[item_basename] = psnr
        if opt.debug:
            # Debug mode stops after the first successfully evaluated item.
            break

    print(results_dict)

    # np.mean([]) is NaN, which json would emit as the non-standard token
    # ``NaN``; write null instead when no item was evaluated.
    results_dict["all_mean"] = float(np.mean(psnr_list)) if psnr_list else None
    write_json(os.path.join(save_folder, "PSNR.json"), results_dict)