Spaces:
Running
Running
import torch | |
from rstor.properties import METRIC_PSNR, METRIC_SSIM, METRIC_LPIPS, REDUCTION_AVERAGE, REDUCTION_SKIP, REDUCTION_SUM | |
from torchmetrics.image import StructuralSimilarityIndexMeasure as SSIM | |
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity | |
from typing import List, Optional | |
ALL_METRICS = [METRIC_PSNR, METRIC_SSIM, METRIC_LPIPS] | |
def compute_psnr(
    predic: torch.Tensor,
    target: torch.Tensor,
    clamp_mse=1e-10,
    reduction: Optional[str] = REDUCTION_AVERAGE
) -> torch.Tensor:
    """
    Compute the PSNR (in dB) between predicted and target images.

    The PSNR of each image is 10 * log10(1 / MSE), i.e. a peak signal value
    of 1 is used, with the MSE averaged over the last three dimensions.

    Args:
        predic (torch.Tensor): [N, C, H, W] predicted values.
        target (torch.Tensor): [N, C, H, W] target values.
        clamp_mse (float): Lower bound applied to the per-image MSE before
            taking the logarithm, so identical images give a finite PSNR
            (the default 1e-10 caps the PSNR at 100 dB).
        reduction (str): Reduction method over the per-image values.
            REDUCTION_AVERAGE (mean), REDUCTION_SUM (sum)
            or REDUCTION_SKIP (no reduction, one value per image).
    Returns:
        torch.Tensor: Reduced PSNR, or the per-image PSNR values
            when reduction is REDUCTION_SKIP.
    Raises:
        ValueError: If `reduction` is not one of the supported methods.
    """
    with torch.no_grad():
        squared_error = (predic - target).pow(2)
        # Clamp before the division below to avoid log10(1 / 0) = inf.
        per_image_mse = squared_error.mean(dim=(-3, -2, -1)).clamp(min=clamp_mse)
        per_image_psnr = 10 * torch.log10(1 / per_image_mse)
        if reduction == REDUCTION_AVERAGE:
            return per_image_psnr.mean()
        if reduction == REDUCTION_SUM:
            return per_image_psnr.sum()
        if reduction == REDUCTION_SKIP:
            return per_image_psnr
    raise ValueError(f"Unknown reduction {reduction}")
def compute_ssim(
    predic: torch.Tensor,
    target: torch.Tensor,
    reduction: Optional[str] = REDUCTION_AVERAGE
) -> torch.Tensor:
    """
    Compute the SSIM metric between predicted and target images.

    Relies on the torchmetrics StructuralSimilarityIndexMeasure,
    configured with data_range=1.0.

    Args:
        predic (torch.Tensor): [N, C, H, W] predicted values.
        target (torch.Tensor): [N, C, H, W] target values.
        reduction (str): Reduction method.
            REDUCTION_AVERAGE/REDUCTION_SUM/REDUCTION_SKIP.
    Returns:
        torch.Tensor: The SSIM value returned by torchmetrics
            for the requested reduction.
    Raises:
        KeyError: If `reduction` is not one of the supported methods.
        AssertionError: If `predic` and `target` differ in shape or device.
    """
    # Translate the project reduction constant into the torchmetrics keyword.
    torchmetrics_reduction = {
        REDUCTION_SKIP: None,
        REDUCTION_AVERAGE: "elementwise_mean",
        REDUCTION_SUM: "sum"
    }[reduction]
    with torch.no_grad():
        # The metric module is moved to the device of the predictions.
        ssim_metric = SSIM(data_range=1.0, reduction=torchmetrics_reduction)
        ssim_metric = ssim_metric.to(predic.device)
        assert predic.shape == target.shape, f"{predic.shape} != {target.shape}"
        assert predic.device == target.device, f"{predic.device} != {target.device}"
        return ssim_metric(predic, target)
def compute_lpips(
    predic: torch.Tensor,
    target: torch.Tensor,
    reduction: Optional[str] = REDUCTION_AVERAGE,
) -> torch.Tensor:
    """
    Compute the LPIPS metric between predicted and target images.
    https://richzhang.github.io/PerceptualSimilarity/

    Inputs are clipped to [0, 1] before being passed to the torchmetrics
    LearnedPerceptualImagePatchSimilarity module (created with normalize=True).

    Args:
        predic (torch.Tensor): [N, C, H, W] predicted values.
        target (torch.Tensor): [N, C, H, W] target values.
        reduction (str): Reduction method.
            REDUCTION_AVERAGE/REDUCTION_SUM/REDUCTION_SKIP.
    Returns:
        torch.Tensor: The reduced LPIPS value for the batch, or the stacked
            per-image LPIPS values when reduction is REDUCTION_SKIP.
    Raises:
        KeyError: If `reduction` is not one of the supported methods.
        AssertionError: If `predic` and `target` differ in shape or device.
    """
    torchmetrics_reduction = {
        # Each call receives a single image in skip mode,
        # so the reduction choice does not really matter there.
        REDUCTION_SKIP: "sum",
        REDUCTION_AVERAGE: "mean",
        REDUCTION_SUM: "sum"
    }[reduction]
    with torch.no_grad():
        lpips_module = LearnedPerceptualImagePatchSimilarity(
            reduction=torchmetrics_reduction,
            normalize=True  # If set to True will instead expect input to be in the [0,1] range.
        ).to(predic.device)
        assert predic.shape == target.shape, f"{predic.shape} != {target.shape}"
        assert predic.device == target.device, f"{predic.device} != {target.device}"
        clipped_predic = predic.clip(0, 1)
        clipped_target = target.clip(0, 1)
        if reduction == REDUCTION_SKIP:
            # Score the images one at a time to get one value per image.
            per_image_values = [
                lpips_module(
                    clipped_predic[idx, ...].unsqueeze(0),
                    clipped_target[idx, ...].unsqueeze(0)
                )
                for idx in range(predic.shape[0])
            ]
            return torch.stack(per_image_values)
        # Only REDUCTION_SUM and REDUCTION_AVERAGE can reach this point:
        # any other value already failed the lookup above.
        return lpips_module(clipped_predic, clipped_target)
def compute_metrics(
    predic: torch.Tensor,
    target: torch.Tensor,
    reduction: Optional[str] = REDUCTION_AVERAGE,
    chosen_metrics: Optional[List[str]] = ALL_METRICS) -> dict:
    """
    Compute the metrics for a batch of predicted and true values.

    Args:
        predic (torch.Tensor): [N, C, H, W] predicted values.
        target (torch.Tensor): [N, C, H, W] target values.
        reduction (str): Reduction method.
            REDUCTION_AVERAGE/REDUCTION_SKIP/REDUCTION_SUM.
        chosen_metrics (list): List of metrics to compute, default ALL_METRICS
            ([METRIC_PSNR, METRIC_SSIM, METRIC_LPIPS]). None is treated as
            ALL_METRICS. Names other than these three are ignored.
    Returns:
        dict: computed metrics, keyed by metric name. Values are Python
            scalars (via `.item()`), except with REDUCTION_SKIP where the
            tensors returned by the metric functions are kept as is.
    """
    # The signature advertises Optional, so None must not crash the
    # membership tests below: fall back to every available metric.
    if chosen_metrics is None:
        chosen_metrics = ALL_METRICS
    # Order matters: it defines the key order of the returned dictionary.
    metric_functions = (
        (METRIC_PSNR, compute_psnr),
        (METRIC_SSIM, compute_ssim),
        (METRIC_LPIPS, compute_lpips),
    )
    # Unreduced results hold several values and cannot go through .item().
    keep_tensor = reduction == REDUCTION_SKIP
    metrics = {}
    for metric_name, metric_function in metric_functions:
        if metric_name not in chosen_metrics:
            continue
        value = metric_function(predic, target, reduction=reduction)
        metrics[metric_name] = value if keep_tensor else value.item()
    return metrics