from typing import List, Optional

import torch
from torchmetrics.image import StructuralSimilarityIndexMeasure as SSIM
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

from rstor.properties import (
    METRIC_PSNR,
    METRIC_SSIM,
    METRIC_LPIPS,
    REDUCTION_AVERAGE,
    REDUCTION_SKIP,
    REDUCTION_SUM,
)

ALL_METRICS = [METRIC_PSNR, METRIC_SSIM, METRIC_LPIPS]


def _lookup_reduction(reduction: Optional[str], table: dict):
    """Translate a project reduction name into a backend-specific mode.

    Raises the same ``ValueError`` as ``compute_psnr`` for unknown names,
    so every metric rejects a bad ``reduction`` consistently.
    """
    try:
        return table[reduction]
    except KeyError:
        raise ValueError(f"Unknown reduction {reduction}") from None


def compute_psnr(
    predic: torch.Tensor,
    target: torch.Tensor,
    clamp_mse=1e-10,
    reduction: Optional[str] = REDUCTION_AVERAGE
) -> torch.Tensor:
    """
    Compute the PSNR metric for a batch of predicted and true values.

    The formula uses a peak signal value of 1.0 (``10 * log10(1 / mse)``),
    so inputs are presumably expected in the [0, 1] range.

    Args:
        predic (torch.Tensor): [N, C, H, W] predicted values.
        target (torch.Tensor): [N, C, H, W] target values.
        clamp_mse (float): Lower bound applied to the per-image MSE so that
            identical images yield a finite PSNR instead of +inf.
        reduction (str): Reduction method. REDUCTION_AVERAGE/REDUCTION_SKIP/REDUCTION_SUM.

    Returns:
        torch.Tensor: Scalar mean (REDUCTION_AVERAGE) or sum (REDUCTION_SUM)
        of the per-image PSNR, or the per-image PSNR values (REDUCTION_SKIP).

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
    """
    with torch.no_grad():
        # Mean over the last three axes (C, H, W) -> one MSE per image.
        mse_per_image = torch.mean((predic - target) ** 2, dim=(-3, -2, -1))
        mse_per_image = torch.clamp(mse_per_image, min=clamp_mse)
        psnr_per_image = 10 * torch.log10(1 / mse_per_image)
        if reduction == REDUCTION_AVERAGE:
            average_psnr = torch.mean(psnr_per_image)
        elif reduction == REDUCTION_SUM:
            average_psnr = torch.sum(psnr_per_image)
        elif reduction == REDUCTION_SKIP:
            average_psnr = psnr_per_image
        else:
            raise ValueError(f"Unknown reduction {reduction}")
    return average_psnr


def compute_ssim(
    predic: torch.Tensor,
    target: torch.Tensor,
    reduction: Optional[str] = REDUCTION_AVERAGE
) -> torch.Tensor:
    """
    Compute the SSIM metric for a batch of predicted and true values.

    Uses torchmetrics' SSIM with ``data_range=1.0``.

    Args:
        predic (torch.Tensor): [N, C, H, W] predicted values.
        target (torch.Tensor): [N, C, H, W] target values.
        reduction (str): Reduction method. REDUCTION_AVERAGE/REDUCTION_SKIP/REDUCTION_SUM.

    Returns:
        torch.Tensor: Mean (REDUCTION_AVERAGE) or sum (REDUCTION_SUM) SSIM over
        the batch, or the per-image SSIM values (REDUCTION_SKIP).

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
    """
    with torch.no_grad():
        reduction_mode = _lookup_reduction(reduction, {
            REDUCTION_SKIP: None,  # torchmetrics: no reduction -> per-image values
            REDUCTION_AVERAGE: "elementwise_mean",
            REDUCTION_SUM: "sum",
        })
        ssim = SSIM(data_range=1.0, reduction=reduction_mode).to(predic.device)
        assert predic.shape == target.shape, f"{predic.shape} != {target.shape}"
        assert predic.device == target.device, f"{predic.device} != {target.device}"
        ssim_value = ssim(predic, target)
    return ssim_value


def compute_lpips(
    predic: torch.Tensor,
    target: torch.Tensor,
    reduction: Optional[str] = REDUCTION_AVERAGE,
) -> torch.Tensor:
    """
    Compute the LPIPS metric for a batch of predicted and true values.
    https://richzhang.github.io/PerceptualSimilarity/

    Inputs are clipped to [0, 1] before scoring, because the metric is built
    with ``normalize=True`` and therefore expects that range.

    Args:
        predic (torch.Tensor): [N, C, H, W] predicted values.
        target (torch.Tensor): [N, C, H, W] target values.
        reduction (str): Reduction method. REDUCTION_AVERAGE/REDUCTION_SKIP/REDUCTION_SUM.

    Returns:
        torch.Tensor: Mean (REDUCTION_AVERAGE) or sum (REDUCTION_SUM) LPIPS over
        the batch, or a 1-D tensor of per-image LPIPS values (REDUCTION_SKIP).

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
    """
    reduction_mode = _lookup_reduction(reduction, {
        # With SKIP every call below scores a batch of one image, so "sum"
        # simply returns that image's score.
        REDUCTION_SKIP: "sum",
        REDUCTION_AVERAGE: "mean",
        REDUCTION_SUM: "sum",
    })
    with torch.no_grad():
        # NOTE(review): the LPIPS network is re-created on every call; consider
        # caching it if this function sits on a hot path.
        lpip_metrics = LearnedPerceptualImagePatchSimilarity(
            reduction=reduction_mode,
            normalize=True,  # If set to True will instead expect input to be in the [0,1] range.
        ).to(predic.device)
        assert predic.shape == target.shape, f"{predic.shape} != {target.shape}"
        assert predic.device == target.device, f"{predic.device} != {target.device}"
        # Clip once for the whole batch rather than once per image.
        predic_clipped = predic.clip(0, 1)
        target_clipped = target.clip(0, 1)
        if reduction == REDUCTION_SKIP:
            # Score each image separately (batch of one) to get per-image values.
            lpip_value = torch.stack([
                lpip_metrics(pred_img.unsqueeze(0), target_img.unsqueeze(0))
                for pred_img, target_img in zip(predic_clipped, target_clipped)
            ])
        else:
            lpip_value = lpip_metrics(predic_clipped, target_clipped)
    return lpip_value


def compute_metrics(
    predic: torch.Tensor,
    target: torch.Tensor,
    reduction: Optional[str] = REDUCTION_AVERAGE,
    chosen_metrics: Optional[List[str]] = None,
) -> dict:
    """
    Compute the metrics for a batch of predicted and true values.

    Args:
        predic (torch.Tensor): [N, C, H, W] predicted values.
        target (torch.Tensor): [N, C, H, W] target values.
        reduction (str): Reduction method. REDUCTION_AVERAGE/REDUCTION_SKIP/REDUCTION_SUM.
        chosen_metrics (list): Metrics to compute. ``None`` (the default) means
            ALL_METRICS, i.e. [METRIC_PSNR, METRIC_SSIM, METRIC_LPIPS].

    Returns:
        dict: Metric name -> value. Values are Python floats for
        REDUCTION_AVERAGE/REDUCTION_SUM and per-image tensors for REDUCTION_SKIP.

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
    """
    # A None default avoids sharing the module-level list as a mutable default.
    if chosen_metrics is None:
        chosen_metrics = ALL_METRICS
    # Fixed evaluation order keeps the result dict in PSNR, SSIM, LPIPS order.
    metric_functions = (
        (METRIC_PSNR, compute_psnr),
        (METRIC_SSIM, compute_ssim),
        (METRIC_LPIPS, compute_lpips),
    )
    metrics = {}
    for metric_name, metric_fn in metric_functions:
        if metric_name not in chosen_metrics:
            continue
        value = metric_fn(predic, target, reduction=reduction)
        # Reduced results are scalar tensors -> convert to plain floats.
        metrics[metric_name] = value if reduction == REDUCTION_SKIP else value.item()
    return metrics