Spaces:
Runtime error
Runtime error
import os | |
import numpy as np | |
import torch | |
from PIL import Image | |
from skimage.metrics import ( | |
mean_squared_error, | |
peak_signal_noise_ratio, | |
structural_similarity as structural_similarity_index_measure, | |
normalized_mutual_information, | |
) | |
from tqdm.auto import tqdm | |
from concurrent.futures import ThreadPoolExecutor | |
# Convert a pair of PIL images to NumPy arrays
def convert_image_pair_to_numpy(image1, image2):
    """Convert two PIL images into NumPy arrays of identical shape.

    Args:
        image1: First image; must be a ``PIL.Image.Image``.
        image2: Second image; must be a ``PIL.Image.Image``.

    Returns:
        A 2-tuple ``(array1, array2)`` holding ``np.array(image1)`` and
        ``np.array(image2)``.

    Raises:
        AssertionError: If either argument is not a PIL image, or if the
            two converted arrays differ in shape.
    """
    for candidate in (image1, image2):
        assert isinstance(candidate, Image.Image)
    first, second = np.array(image1), np.array(image2)
    # The metrics built on top of this helper compare the arrays element-wise,
    # so a shape mismatch is rejected here.
    assert first.shape == second.shape
    return first, second
# Compute MSE between two images
def compute_mse(image1, image2):
    """Return the mean squared error between two PIL images.

    Args:
        image1: First ``PIL.Image.Image``.
        image2: Second ``PIL.Image.Image``; its array shape must match
            that of ``image1``.

    Returns:
        The result of ``skimage.metrics.mean_squared_error`` as a ``float``.

    Raises:
        AssertionError: Propagated from ``convert_image_pair_to_numpy`` when
            the inputs are not PIL images or their shapes differ.
    """
    arrays = convert_image_pair_to_numpy(image1, image2)
    mse = mean_squared_error(*arrays)
    return float(mse)
# Compute PSNR between two images
def compute_psnr(image1, image2):
    """Return the peak signal-to-noise ratio between two PIL images.

    Args:
        image1: First ``PIL.Image.Image``; passed to scikit-image as the
            first (reference) argument.
        image2: Second ``PIL.Image.Image``; its array shape must match
            that of ``image1``.

    Returns:
        The result of ``skimage.metrics.peak_signal_noise_ratio`` as a
        ``float``.

    Raises:
        AssertionError: Propagated from ``convert_image_pair_to_numpy`` when
            the inputs are not PIL images or their shapes differ.
    """
    arrays = convert_image_pair_to_numpy(image1, image2)
    psnr = peak_signal_noise_ratio(*arrays)
    return float(psnr)
# Compute SSIM between two images
def compute_ssim(image1, image2):
    """Return the structural similarity index between two PIL images.

    Args:
        image1: First ``PIL.Image.Image``.
        image2: Second ``PIL.Image.Image``; its array shape must match
            that of ``image1``.

    Returns:
        The result of ``skimage.metrics.structural_similarity`` as a
        ``float``.

    Raises:
        AssertionError: Propagated from ``convert_image_pair_to_numpy`` when
            the inputs are not PIL images or their shapes differ.
    """
    first, second = convert_image_pair_to_numpy(image1, image2)
    # A 3-D array carries its channels on the last axis (index 2). A 2-D array
    # (single-channel image) has no axis 2, so passing channel_axis=2 for it
    # would fail; None tells scikit-image to treat the input as grayscale.
    channel_axis = 2 if first.ndim == 3 else None
    return float(
        structural_similarity_index_measure(
            first, second, channel_axis=channel_axis
        )
    )
# Compute NMI between two images
def compute_nmi(image1, image2):
    """Return the normalized mutual information between two PIL images.

    Args:
        image1: First ``PIL.Image.Image``.
        image2: Second ``PIL.Image.Image``; its array shape must match
            that of ``image1``.

    Returns:
        The result of ``skimage.metrics.normalized_mutual_information``
        (called with its default settings) as a ``float``.

    Raises:
        AssertionError: Propagated from ``convert_image_pair_to_numpy`` when
            the inputs are not PIL images or their shapes differ.
    """
    arrays = convert_image_pair_to_numpy(image1, image2)
    nmi = normalized_mutual_information(*arrays)
    return float(nmi)
# Compute a metric over many image pairs in parallel
def compute_metric_repeated(
    images1, images2, metric_func, num_workers=None, verbose=False
):
    """Apply a pairwise image metric to two parallel lists of PIL images.

    Pairs ``(images1[i], images2[i])`` are evaluated on a thread pool and the
    results are returned in input order.

    Args:
        images1: List of ``PIL.Image.Image`` objects.
        images2: List of ``PIL.Image.Image`` objects, same length as
            ``images1``.
        metric_func: Callable taking ``(image1, image2)`` and returning the
            metric value for that pair.
        num_workers: Number of worker threads. When given it must be at
            least 1 and, if the CPU count can be determined, no larger than
            ``os.cpu_count()``. When ``None`` it defaults to
            ``max(torch.cuda.device_count() * 4, 8)``.
        verbose: When true, wrap the result stream in a ``tqdm`` progress
            bar labelled with the metric name.

    Returns:
        A list with one metric value per image pair; ``[]`` when both input
        lists are empty.

    Raises:
        AssertionError: If the inputs are not lists, differ in length, do
            not start with a PIL image, or ``num_workers`` is out of range.
    """
    assert isinstance(images1, list), "images1 must be a list of PIL images"
    assert isinstance(images2, list), "images2 must be a list of PIL images"
    assert len(images1) == len(images2), (
        "images1 and images2 must have the same length"
    )
    # Nothing to compare: return before indexing element 0 of an empty list.
    if not images1:
        return []
    assert isinstance(images1[0], Image.Image), "images1 must hold PIL images"
    assert isinstance(images2[0], Image.Image), "images2 must hold PIL images"

    if num_workers is not None:
        # os.cpu_count() returns None when the count cannot be determined;
        # in that case only the lower bound can be enforced.
        cpu_total = os.cpu_count()
        assert num_workers >= 1 and (
            cpu_total is None or num_workers <= cpu_total
        ), "num_workers must be between 1 and os.cpu_count()"
    else:
        num_workers = max(torch.cuda.device_count() * 4, 8)

    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        results = executor.map(metric_func, images1, images2)
        if verbose:
            # Derive the progress-bar label lazily and defensively: names such
            # as "compute_psnr" yield "PSNR", while callables without an
            # underscore in their name (or without __name__ at all, e.g.
            # functools.partial objects) fall back to a generic label.
            func_name = getattr(metric_func, "__name__", "")
            name_parts = func_name.split("_")
            if len(name_parts) > 1:
                metric_name = name_parts[1].upper()
            else:
                metric_name = func_name.upper() or "METRIC"
            results = tqdm(
                results,
                total=len(images1),
                desc=f"{metric_name} ",
            )
        # Drain the iterator while the pool is still alive.
        return list(results)
# Compute MSE between pairs of images
def compute_mse_repeated(images1, images2, num_workers=None, verbose=False):
    """Compute the MSE for every pair ``(images1[i], images2[i])``.

    Thin wrapper that forwards to ``compute_metric_repeated`` with
    ``compute_mse`` as the metric; see that function for the meaning of
    ``num_workers`` and ``verbose``.

    Returns:
        Whatever ``compute_metric_repeated`` returns for ``compute_mse``.
    """
    return compute_metric_repeated(
        images1,
        images2,
        metric_func=compute_mse,
        num_workers=num_workers,
        verbose=verbose,
    )
# Compute PSNR between pairs of images
def compute_psnr_repeated(images1, images2, num_workers=None, verbose=False):
    """Compute the PSNR for every pair ``(images1[i], images2[i])``.

    Thin wrapper that forwards to ``compute_metric_repeated`` with
    ``compute_psnr`` as the metric; see that function for the meaning of
    ``num_workers`` and ``verbose``.

    Returns:
        Whatever ``compute_metric_repeated`` returns for ``compute_psnr``.
    """
    return compute_metric_repeated(
        images1,
        images2,
        metric_func=compute_psnr,
        num_workers=num_workers,
        verbose=verbose,
    )
# Compute SSIM between pairs of images
def compute_ssim_repeated(images1, images2, num_workers=None, verbose=False):
    """Compute the SSIM for every pair ``(images1[i], images2[i])``.

    Thin wrapper that forwards to ``compute_metric_repeated`` with
    ``compute_ssim`` as the metric; see that function for the meaning of
    ``num_workers`` and ``verbose``.

    Returns:
        Whatever ``compute_metric_repeated`` returns for ``compute_ssim``.
    """
    return compute_metric_repeated(
        images1,
        images2,
        metric_func=compute_ssim,
        num_workers=num_workers,
        verbose=verbose,
    )
# Compute NMI between pairs of images
def compute_nmi_repeated(images1, images2, num_workers=None, verbose=False):
    """Compute the NMI for every pair ``(images1[i], images2[i])``.

    Thin wrapper that forwards to ``compute_metric_repeated`` with
    ``compute_nmi`` as the metric; see that function for the meaning of
    ``num_workers`` and ``verbose``.

    Returns:
        Whatever ``compute_metric_repeated`` returns for ``compute_nmi``.
    """
    return compute_metric_repeated(
        images1,
        images2,
        metric_func=compute_nmi,
        num_workers=num_workers,
        verbose=verbose,
    )
def compute_image_distance_repeated(
    images1, images2, metric_name, num_workers=None, verbose=False
):
    """Compute a metric, selected by name, for every pair of images.

    Args:
        images1: List of ``PIL.Image.Image`` objects.
        images2: List of ``PIL.Image.Image`` objects, same length as
            ``images1``.
        metric_name: One of ``"mse"``, ``"psnr"``, ``"ssim"`` or ``"nmi"``.
        num_workers: Forwarded to ``compute_metric_repeated``.
        verbose: Forwarded to ``compute_metric_repeated``.

    Returns:
        The list returned by ``compute_metric_repeated`` for the selected
        metric.

    Raises:
        KeyError: If ``metric_name`` is not one of the supported names.
    """
    # "mse" is included so every pairwise metric defined in this module is
    # reachable by name.
    metric_funcs = {
        "mse": compute_mse,
        "psnr": compute_psnr,
        "ssim": compute_ssim,
        "nmi": compute_nmi,
    }
    try:
        metric_func = metric_funcs[metric_name]
    except KeyError:
        # Same exception type as a plain dict lookup, but with a message that
        # tells the caller which names are accepted.
        raise KeyError(
            f"Unknown metric_name {metric_name!r}; "
            f"expected one of {sorted(metric_funcs)}"
        ) from None
    return compute_metric_repeated(
        images1, images2, metric_func, num_workers, verbose
    )