balthou's picture
initiate demo
cec5823
import torch
from rstor.properties import METRIC_PSNR, METRIC_SSIM, METRIC_LPIPS, REDUCTION_AVERAGE, REDUCTION_SKIP, REDUCTION_SUM
from torchmetrics.image import StructuralSimilarityIndexMeasure as SSIM
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from typing import List, Optional
# Every metric name that compute_metrics knows how to evaluate; it is the
# selection used when compute_metrics is not given an explicit list.
ALL_METRICS = [
    METRIC_PSNR,
    METRIC_SSIM,
    METRIC_LPIPS,
]
def compute_psnr(
    predic: torch.Tensor,
    target: torch.Tensor,
    clamp_mse=1e-10,
    reduction: Optional[str] = REDUCTION_AVERAGE,
    data_range: float = 1.0,
) -> torch.Tensor:
    """
    Compute the PSNR metric for a batch of predicted and true values.

    PSNR = 10 * log10(data_range**2 / MSE), with the MSE taken per image.

    Args:
        predic (torch.Tensor): [N, C, H, W] predicted values.
        target (torch.Tensor): [N, C, H, W] target values.
        clamp_mse (float): Lower bound applied to each per-image MSE before the
            logarithm, so identical images give a finite PSNR
            (10 * log10(data_range**2 / clamp_mse)) instead of +inf.
        reduction (str): Reduction method. REDUCTION_AVERAGE/REDUCTION_SKIP/REDUCTION_SUM.
        data_range (float): Maximum possible signal value (the "MAX" in the
            PSNR formula). Defaults to 1.0, i.e. images expected in [0, 1].

    Returns:
        torch.Tensor: Scalar mean (REDUCTION_AVERAGE) or sum (REDUCTION_SUM) of the
        per-image PSNR values, or the [N] per-image values (REDUCTION_SKIP).

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
    """
    with torch.no_grad():
        # Mean over the last three axes (C, H, W) -> one MSE per image.
        mse_per_image = torch.mean((predic - target) ** 2, dim=(-3, -2, -1))
        # Avoid log10(x / 0) = +inf for pixel-perfect predictions.
        mse_per_image = torch.clamp(mse_per_image, min=clamp_mse)
        psnr_per_image = 10 * torch.log10(data_range ** 2 / mse_per_image)
        if reduction == REDUCTION_AVERAGE:
            return torch.mean(psnr_per_image)
        if reduction == REDUCTION_SUM:
            return torch.sum(psnr_per_image)
        if reduction == REDUCTION_SKIP:
            return psnr_per_image
        raise ValueError(f"Unknown reduction {reduction}")
def compute_ssim(
    predic: torch.Tensor,
    target: torch.Tensor,
    reduction: Optional[str] = REDUCTION_AVERAGE,
    data_range: float = 1.0,
) -> torch.Tensor:
    """
    Compute the SSIM metric for a batch of predicted and true values.

    Args:
        predic (torch.Tensor): [N, C, H, W] predicted values.
        target (torch.Tensor): [N, C, H, W] target values.
        reduction (str): Reduction method. REDUCTION_AVERAGE/REDUCTION_SKIP/REDUCTION_SUM.
        data_range (float): Value range of the images passed to torchmetrics'
            SSIM. Defaults to 1.0, i.e. images expected in [0, 1].

    Returns:
        torch.Tensor: Scalar mean (REDUCTION_AVERAGE) or sum (REDUCTION_SUM) of the
        per-image SSIM values, or the unreduced per-image values (REDUCTION_SKIP).

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
        AssertionError: If ``predic`` and ``target`` differ in shape or device.
    """
    # Map this module's reduction names onto torchmetrics' `reduction` argument.
    reduction_modes = {
        REDUCTION_SKIP: None,
        REDUCTION_AVERAGE: "elementwise_mean",
        REDUCTION_SUM: "sum",
    }
    if reduction not in reduction_modes:
        raise ValueError(f"Unknown reduction {reduction}")
    # Validate inputs before building the metric module.
    assert predic.shape == target.shape, f"{predic.shape} != {target.shape}"
    assert predic.device == target.device, f"{predic.device} != {target.device}"
    with torch.no_grad():
        ssim = SSIM(data_range=data_range, reduction=reduction_modes[reduction]).to(predic.device)
        return ssim(predic, target)
def compute_lpips(
    predic: torch.Tensor,
    target: torch.Tensor,
    reduction: Optional[str] = REDUCTION_AVERAGE,
) -> torch.Tensor:
    """
    Compute the LPIPS metric for a batch of predicted and true values.
    https://richzhang.github.io/PerceptualSimilarity/

    Both inputs are clipped to [0, 1] before scoring.

    Args:
        predic (torch.Tensor): [N, C, H, W] predicted values.
        target (torch.Tensor): [N, C, H, W] target values.
        reduction (str): Reduction method. REDUCTION_AVERAGE/REDUCTION_SKIP/REDUCTION_SUM.

    Returns:
        torch.Tensor: Scalar mean (REDUCTION_AVERAGE) or sum (REDUCTION_SUM) of the
        LPIPS values over the batch, or a [N] tensor with one LPIPS value per
        image (REDUCTION_SKIP).

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
        AssertionError: If ``predic`` and ``target`` differ in shape or device.
    """
    # torchmetrics' LPIPS only offers "mean"/"sum". For REDUCTION_SKIP each call
    # below scores a single image, so the choice does not affect the result.
    reduction_modes = {
        REDUCTION_SKIP: "sum",
        REDUCTION_AVERAGE: "mean",
        REDUCTION_SUM: "sum",
    }
    if reduction not in reduction_modes:
        raise ValueError(f"Unknown reduction {reduction}")
    # Validate inputs before instantiating the (pretrained) LPIPS network.
    assert predic.shape == target.shape, f"{predic.shape} != {target.shape}"
    assert predic.device == target.device, f"{predic.device} != {target.device}"
    with torch.no_grad():
        lpip_metrics = LearnedPerceptualImagePatchSimilarity(
            reduction=reduction_modes[reduction],
            normalize=True  # If set to True will instead expect input to be in the [0,1] range.
        ).to(predic.device)
        # normalize=True expects [0, 1] inputs: clip the whole batch once.
        predic_unit = predic.clip(0, 1)
        target_unit = target.clip(0, 1)
        if reduction != REDUCTION_SKIP:
            return lpip_metrics(predic_unit, target_unit)
        if predic_unit.shape[0] == 0:
            # torch.stack refuses an empty list; return an empty per-image result.
            return torch.empty(0, device=predic.device)
        # Score each image separately (split(1) keeps the batch axis: [1, C, H, W]).
        per_image = [
            lpip_metrics(pred_one, target_one)
            for pred_one, target_one in zip(predic_unit.split(1), target_unit.split(1))
        ]
        return torch.stack(per_image)
def compute_metrics(
    predic: torch.Tensor,
    target: torch.Tensor,
    reduction: Optional[str] = REDUCTION_AVERAGE,
    chosen_metrics: Optional[List[str]] = None) -> dict:
    """
    Compute the metrics for a batch of predicted and true values.

    Args:
        predic (torch.Tensor): [N, C, H, W] predicted values.
        target (torch.Tensor): [N, C, H, W] target values.
        reduction (str): Reduction method. REDUCTION_AVERAGE/REDUCTION_SKIP/REDUCTION_SUM.
        chosen_metrics (list): Metric names to compute. Defaults to ALL_METRICS
            ([METRIC_PSNR, METRIC_SSIM, METRIC_LPIPS]). Unrecognised names are ignored.

    Returns:
        dict: Metric name -> value, in PSNR/SSIM/LPIPS order. Values are Python
        numbers (via ``.item()``) for REDUCTION_AVERAGE/REDUCTION_SUM and the
        unreduced tensors for REDUCTION_SKIP.
    """
    # Resolve the default here instead of sharing the module-level list as a default argument.
    if chosen_metrics is None:
        chosen_metrics = ALL_METRICS
    metric_functions = {
        METRIC_PSNR: compute_psnr,
        METRIC_SSIM: compute_ssim,
        METRIC_LPIPS: compute_lpips,
    }
    metrics = {}
    for metric_name, metric_fn in metric_functions.items():
        if metric_name not in chosen_metrics:
            continue
        value = metric_fn(predic, target, reduction=reduction)
        # Reduced results are scalar tensors; unwrap them to plain Python numbers.
        metrics[metric_name] = value if reduction == REDUCTION_SKIP else value.item()
    return metrics