image-deblurring / test /test_metrics.py
balthou's picture
initiate demo
cec5823
import torch
import numpy as np
from rstor.learning.metrics import compute_psnr, compute_ssim, compute_metrics, compute_lpips
from rstor.properties import REDUCTION_AVERAGE, REDUCTION_SKIP, REDUCTION_SUM, DEVICE
from rstor.properties import METRIC_PSNR, METRIC_SSIM, METRIC_LPIPS
def test_compute_psnr():
# Test case 1: Identical values
predic = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]])
target = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]])
assert torch.isinf(compute_psnr(predic, target, clamp_mse=0)), "Test case 1 failed"
# Test case 2: Predic and target have different values
predic = torch.tensor([[[[0., 0.], [0., 0.]]]])
target = torch.tensor([[[[0.25, 0.25], [0.25, 0.25]]]])
assert compute_psnr(predic, target).item() == (10. * torch.log10(torch.Tensor([4.**2]))).item() # 12db
print("All tests passed.")
def test_compute_ssim():
x = torch.rand(8, 3, 256, 256)
y = torch.rand(8, 3, 256, 256)
ssim = compute_ssim(x, y, reduction=REDUCTION_AVERAGE)
ssim_per_unit = compute_ssim(x, y, reduction=REDUCTION_SKIP)
assert ssim_per_unit.shape == (8,), "SSIM Test case 1 failed"
assert ssim_per_unit.mean() == ssim, "SSIM Test case 2 failed"
def test_compute_lpips():
for i in range(2):
x = torch.rand(8, 3, 256, 256).to(DEVICE)
y = torch.rand(8, 3, 256, 256).to(DEVICE)
lpips = compute_lpips(x, y, reduction=REDUCTION_AVERAGE)
lpips_per_unit = compute_lpips(x, y, reduction=REDUCTION_SKIP)
assert lpips_per_unit.shape == (8,), "LPIPS Test case 1 failed"
assert torch.isclose(lpips_per_unit.mean(), lpips), "LPIPS Test case 2 failed"
def test_compute_metrics():
x = torch.rand(8, 3, 256, 256) # negative value ensures that we check clamping for LPIPS
y = x.clone() + torch.randn(8, 3, 256, 256) * 0.01
metrics = compute_metrics(x, y)
print(metrics)
metric_per_image = compute_metrics(x, y, reduction=REDUCTION_SKIP)
metric_sum_reduction = compute_metrics(x, y, reduction=REDUCTION_SUM)
assert metric_per_image[METRIC_PSNR].shape == (8,), "Metrics Test case 1 failed"
assert metric_per_image[METRIC_SSIM].shape == (8,), "Metrics Test case 2 failed"
assert metric_per_image[METRIC_LPIPS].shape == (8,), "Metrics Test case 3 failed"
assert np.isclose(metric_per_image[METRIC_PSNR].mean().item(), metrics[METRIC_PSNR]), "Metrics Test case 4 failed"
assert np.isclose(metric_per_image[METRIC_PSNR].sum().item(),
metric_sum_reduction[METRIC_PSNR]), "Metrics Test case 5 failed"
assert np.isclose(metrics[METRIC_PSNR],
metric_sum_reduction[METRIC_PSNR]/8.), "Metrics Test case 6 failed"