import os import numpy as np import torch from PIL import Image from skimage.metrics import ( mean_squared_error, peak_signal_noise_ratio, structural_similarity as structural_similarity_index_measure, normalized_mutual_information, ) from tqdm.auto import tqdm from concurrent.futures import ThreadPoolExecutor # Process images to numpy arrays def convert_image_pair_to_numpy(image1, image2): assert isinstance(image1, Image.Image) and isinstance(image2, Image.Image) image1_np = np.array(image1) image2_np = np.array(image2) assert image1_np.shape == image2_np.shape return image1_np, image2_np # Compute MSE between two images def compute_mse(image1, image2): image1_np, image2_np = convert_image_pair_to_numpy(image1, image2) return float(mean_squared_error(image1_np, image2_np)) # Compute PSNR between two images def compute_psnr(image1, image2): image1_np, image2_np = convert_image_pair_to_numpy(image1, image2) return float(peak_signal_noise_ratio(image1_np, image2_np)) # Compute SSIM between two images def compute_ssim(image1, image2): image1_np, image2_np = convert_image_pair_to_numpy(image1, image2) return float( structural_similarity_index_measure(image1_np, image2_np, channel_axis=2) ) # Compute NMI between two images def compute_nmi(image1, image2): image1_np, image2_np = convert_image_pair_to_numpy(image1, image2) return float(normalized_mutual_information(image1_np, image2_np)) # Compute metrics def compute_metric_repeated( images1, images2, metric_func, num_workers=None, verbose=False ): # Accept list of PIL images assert isinstance(images1, list) and isinstance(images1[0], Image.Image) assert isinstance(images2, list) and isinstance(images2[0], Image.Image) assert len(images1) == len(images2) if num_workers is not None: assert 1 <= num_workers <= os.cpu_count() else: num_workers = max(torch.cuda.device_count() * 4, 8) metric_name = metric_func.__name__.split("_")[1].upper() with ThreadPoolExecutor(max_workers=num_workers) as executor: tasks = executor.map(metric_func, images1, images2) values = ( list(tasks) if not verbose else list( tqdm( tasks, total=len(images1), desc=f"{metric_name} ", ) ) ) return values # Compute MSE between pairs of images def compute_mse_repeated(images1, images2, num_workers=None, verbose=False): return compute_metric_repeated(images1, images2, compute_mse, num_workers, verbose) # Compute PSNR between pairs of images def compute_psnr_repeated(images1, images2, num_workers=None, verbose=False): return compute_metric_repeated(images1, images2, compute_psnr, num_workers, verbose) # Compute SSIM between pairs of images def compute_ssim_repeated(images1, images2, num_workers=None, verbose=False): return compute_metric_repeated(images1, images2, compute_ssim, num_workers, verbose) # Compute NMI between pairs of images def compute_nmi_repeated(images1, images2, num_workers=None, verbose=False): return compute_metric_repeated(images1, images2, compute_nmi, num_workers, verbose) def compute_image_distance_repeated( images1, images2, metric_name, num_workers=None, verbose=False ): metric_func = { "psnr": compute_psnr, "ssim": compute_ssim, "nmi": compute_nmi, }[metric_name] return compute_metric_repeated(images1, images2, metric_func, num_workers, verbose)