Spaces:
Running
Running
File size: 5,521 Bytes
cec5823 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
import torch
from rstor.properties import METRIC_PSNR, METRIC_SSIM, METRIC_LPIPS, REDUCTION_AVERAGE, REDUCTION_SKIP, REDUCTION_SUM
from torchmetrics.image import StructuralSimilarityIndexMeasure as SSIM
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from typing import List, Optional
# Every metric name that compute_metrics recognizes; also its default selection.
ALL_METRICS = [
    METRIC_PSNR,
    METRIC_SSIM,
    METRIC_LPIPS,
]
def compute_psnr(
    predic: torch.Tensor,
    target: torch.Tensor,
    clamp_mse: float = 1e-10,
    reduction: Optional[str] = REDUCTION_AVERAGE,
    data_range: float = 1.0,
) -> torch.Tensor:
    """
    Compute the PSNR metric for a batch of predicted and target values.

    PSNR = 10 * log10(data_range ** 2 / MSE), where the MSE is taken per image
    over the last three dimensions (C, H, W).

    Args:
        predic (torch.Tensor): [N, C, H, W] predicted values.
        target (torch.Tensor): [N, C, H, W] target values.
        clamp_mse (float): Lower bound applied to each per-image MSE so identical
            images give a finite PSNR instead of +inf.
        reduction (str): Reduction method. REDUCTION_AVERAGE/REDUCTION_SKIP/REDUCTION_SUM.
        data_range (float): Peak signal value (1.0 for images in [0, 1]). The
            default reproduces the previous behavior.
    Returns:
        torch.Tensor: The mean (REDUCTION_AVERAGE) or sum (REDUCTION_SUM) of the
        per-image PSNR values, or the unreduced per-image values (REDUCTION_SKIP).
    Raises:
        ValueError: If `reduction` is not one of the supported values.
    """
    # Validate first so a bad argument fails before any tensor work is done.
    if reduction not in (REDUCTION_AVERAGE, REDUCTION_SUM, REDUCTION_SKIP):
        raise ValueError(f"Unknown reduction {reduction}")
    with torch.no_grad():
        mse_per_image = torch.mean((predic - target) ** 2, dim=(-3, -2, -1))
        mse_per_image = torch.clamp(mse_per_image, min=clamp_mse)
        psnr_per_image = 10 * torch.log10(data_range ** 2 / mse_per_image)
        if reduction == REDUCTION_AVERAGE:
            return torch.mean(psnr_per_image)
        if reduction == REDUCTION_SUM:
            return torch.sum(psnr_per_image)
        return psnr_per_image
def compute_ssim(
    predic: torch.Tensor,
    target: torch.Tensor,
    reduction: Optional[str] = REDUCTION_AVERAGE
) -> torch.Tensor:
    """
    Compute the SSIM metric for a batch of predicted and target values.

    Uses torchmetrics' StructuralSimilarityIndexMeasure with data_range=1.0, so
    the inputs are expected to lie in the [0, 1] range.

    Args:
        predic (torch.Tensor): [N, C, H, W] predicted values.
        target (torch.Tensor): [N, C, H, W] target values.
        reduction (str): Reduction method. REDUCTION_AVERAGE/REDUCTION_SKIP/REDUCTION_SUM.
    Returns:
        torch.Tensor: The mean (REDUCTION_AVERAGE) or sum (REDUCTION_SUM) of the
        SSIM values, or the unreduced values (REDUCTION_SKIP).
    Raises:
        ValueError: If `reduction` is not one of the supported values.
        AssertionError: If `predic` and `target` differ in shape or device.
    """
    # Map the project's reduction names onto torchmetrics' `reduction` argument.
    reduction_modes = {
        REDUCTION_SKIP: None,  # no reduction applied by torchmetrics
        REDUCTION_AVERAGE: "elementwise_mean",
        REDUCTION_SUM: "sum",
    }
    # Same error type/message as compute_psnr instead of a bare KeyError.
    if reduction not in reduction_modes:
        raise ValueError(f"Unknown reduction {reduction}")
    # Check inputs before building the metric module.
    assert predic.shape == target.shape, f"{predic.shape} != {target.shape}"
    assert predic.device == target.device, f"{predic.device} != {target.device}"
    with torch.no_grad():
        ssim = SSIM(data_range=1.0, reduction=reduction_modes[reduction]).to(predic.device)
        ssim_value = ssim(predic, target)
    return ssim_value
def compute_lpips(
    predic: torch.Tensor,
    target: torch.Tensor,
    reduction: Optional[str] = REDUCTION_AVERAGE,
) -> torch.Tensor:
    """
    Compute the LPIPS metric for a batch of predicted and target values.
    https://richzhang.github.io/PerceptualSimilarity/

    Both inputs are clipped to [0, 1] before scoring because the torchmetrics
    metric is built with normalize=True, which expects that range.

    Args:
        predic (torch.Tensor): [N, C, H, W] predicted values.
        target (torch.Tensor): [N, C, H, W] target values.
        reduction (str): Reduction method. REDUCTION_AVERAGE/REDUCTION_SKIP/REDUCTION_SUM.
    Returns:
        torch.Tensor: The mean (REDUCTION_AVERAGE) or sum (REDUCTION_SUM) of the
        LPIPS values over the batch, or, for REDUCTION_SKIP, the per-image LPIPS
        values stacked along a new first dimension (one entry per image).
    Raises:
        ValueError: If `reduction` is not one of the supported values.
        AssertionError: If `predic` and `target` differ in shape or device.
    """
    reduction_modes = {
        # SKIP scores one image at a time, so "sum" over a batch of one is
        # simply that image's own score.
        REDUCTION_SKIP: "sum",
        REDUCTION_AVERAGE: "mean",
        REDUCTION_SUM: "sum",
    }
    # Same error type/message as compute_psnr instead of a bare KeyError.
    if reduction not in reduction_modes:
        raise ValueError(f"Unknown reduction {reduction}")
    # Check inputs before instantiating the (pretrained) LPIPS network.
    assert predic.shape == target.shape, f"{predic.shape} != {target.shape}"
    assert predic.device == target.device, f"{predic.device} != {target.device}"
    with torch.no_grad():
        lpip_metrics = LearnedPerceptualImagePatchSimilarity(
            reduction=reduction_modes[reduction],
            normalize=True  # If set to True will instead expect input to be in the [0,1] range.
        ).to(predic.device)
        # Clip once for the whole batch rather than once per image.
        predic_clipped = predic.clip(0, 1)
        target_clipped = target.clip(0, 1)
        if reduction == REDUCTION_SKIP:
            lpip_value = torch.stack([
                lpip_metrics(pred_img.unsqueeze(0), target_img.unsqueeze(0))
                for pred_img, target_img in zip(predic_clipped, target_clipped)
            ])
        else:
            lpip_value = lpip_metrics(predic_clipped, target_clipped)
    return lpip_value
def compute_metrics(
    predic: torch.Tensor,
    target: torch.Tensor,
    reduction: Optional[str] = REDUCTION_AVERAGE,
    chosen_metrics: Optional[List[str]] = ALL_METRICS) -> dict:
    """
    Compute the selected metrics for a batch of predicted and target values.

    Args:
        predic (torch.Tensor): [N, C, H, W] predicted values.
        target (torch.Tensor): [N, C, H, W] target values.
        reduction (str): Reduction method. REDUCTION_AVERAGE/REDUCTION_SKIP/REDUCTION_SUM.
        chosen_metrics (list): Names of the metrics to compute. Defaults to
            ALL_METRICS (METRIC_PSNR, METRIC_SSIM, METRIC_LPIPS); passing None
            also selects all of them. Unrecognized names are ignored.
    Returns:
        dict: Maps each computed metric name to a Python number (via `.item()`)
        for REDUCTION_AVERAGE / REDUCTION_SUM, or to the unreduced tensor for
        REDUCTION_SKIP. Keys are inserted in the order PSNR, SSIM, LPIPS.
    """
    # The annotation allows None; treat it as "all metrics" instead of crashing
    # on `name in None`.
    if chosen_metrics is None:
        chosen_metrics = ALL_METRICS
    # Fixed evaluation order keeps the result dict's key order stable.
    metric_functions = (
        (METRIC_PSNR, compute_psnr),
        (METRIC_SSIM, compute_ssim),
        (METRIC_LPIPS, compute_lpips),
    )
    metrics = {}
    for metric_name, metric_function in metric_functions:
        if metric_name not in chosen_metrics:
            continue
        value = metric_function(predic, target, reduction=reduction)
        # Reduced results are single-value tensors -> plain numbers; SKIP keeps the tensor.
        metrics[metric_name] = value if reduction == REDUCTION_SKIP else value.item()
    return metrics
|