import torch
import numpy as np
from rstor.learning.metrics import compute_psnr, compute_ssim, compute_metrics, compute_lpips
from rstor.properties import REDUCTION_AVERAGE, REDUCTION_SKIP, REDUCTION_SUM, DEVICE
from rstor.properties import METRIC_PSNR, METRIC_SSIM, METRIC_LPIPS


def test_compute_psnr():
    # Test case 1: identical tensors -> zero MSE, infinite PSNR when MSE clamping is disabled
    predic = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]])
    target = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]])
    assert torch.isinf(compute_psnr(predic, target, clamp_mse=0)), "Test case 1 failed"
    # Test case 2: constant error of 0.25 -> MSE = 1/16, PSNR = 10*log10(16) ~= 12.04 dB
    predic = torch.tensor([[[[0., 0.], [0., 0.]]]])
    target = torch.tensor([[[[0.25, 0.25], [0.25, 0.25]]]])
    assert compute_psnr(predic, target).item() == (10. * torch.log10(torch.Tensor([4.**2]))).item()
    print("All tests passed.")
def test_compute_ssim():
    x = torch.rand(8, 3, 256, 256)
    y = torch.rand(8, 3, 256, 256)
    ssim = compute_ssim(x, y, reduction=REDUCTION_AVERAGE)
    ssim_per_unit = compute_ssim(x, y, reduction=REDUCTION_SKIP)
    assert ssim_per_unit.shape == (8,), "SSIM Test case 1 failed"
    # isclose rather than exact equality: float accumulation order may differ.
    assert torch.isclose(ssim_per_unit.mean(), ssim), "SSIM Test case 2 failed"
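
# The reduction contract exercised above, sketched as a standalone helper.
# This is an assumption about the semantics of the REDUCTION_* flags, not
# rstor's actual code: REDUCTION_SKIP keeps the per-image vector, while the
# other modes collapse it to a scalar.
def _reduce(per_image, reduction):
    if reduction == REDUCTION_AVERAGE:
        return per_image.mean()
    if reduction == REDUCTION_SUM:
        return per_image.sum()
    return per_image  # REDUCTION_SKIP: one value per batch element
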
def test_compute_lpips():
    # Run twice: the second pass presumably exercises the cached LPIPS model.
    for _ in range(2):
        x = torch.rand(8, 3, 256, 256).to(DEVICE)
        y = torch.rand(8, 3, 256, 256).to(DEVICE)
        lpips = compute_lpips(x, y, reduction=REDUCTION_AVERAGE)
        lpips_per_unit = compute_lpips(x, y, reduction=REDUCTION_SKIP)
        assert lpips_per_unit.shape == (8,), "LPIPS Test case 1 failed"
        assert torch.isclose(lpips_per_unit.mean(), lpips), "LPIPS Test case 2 failed"
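
# Note: torch.isclose (rather than exact equality) is used above presumably
# because the averaged LPIPS is accumulated in float32, where summation order
# can change the last bits; the inputs are moved to DEVICE since LPIPS
# evaluates a learned network under the hood.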
def test_compute_metrics():
    x = torch.rand(8, 3, 256, 256)
    # The added noise can push y slightly outside [0, 1], which exercises the clamping for LPIPS.
    y = x.clone() + torch.randn(8, 3, 256, 256) * 0.01
    metrics = compute_metrics(x, y)
    print(metrics)
    metric_per_image = compute_metrics(x, y, reduction=REDUCTION_SKIP)
    metric_sum_reduction = compute_metrics(x, y, reduction=REDUCTION_SUM)
    assert metric_per_image[METRIC_PSNR].shape == (8,), "Metrics Test case 1 failed"
    assert metric_per_image[METRIC_SSIM].shape == (8,), "Metrics Test case 2 failed"
    assert metric_per_image[METRIC_LPIPS].shape == (8,), "Metrics Test case 3 failed"
    assert np.isclose(metric_per_image[METRIC_PSNR].mean().item(), metrics[METRIC_PSNR]), "Metrics Test case 4 failed"
    assert np.isclose(metric_per_image[METRIC_PSNR].sum().item(),
                      metric_sum_reduction[METRIC_PSNR]), "Metrics Test case 5 failed"
    assert np.isclose(metrics[METRIC_PSNR],
                      metric_sum_reduction[METRIC_PSNR] / 8.), "Metrics Test case 6 failed"
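
# These tests are written for pytest, but the module can also be run directly
# as a quick smoke check (assuming the rstor dependencies are installed):
if __name__ == "__main__":
    test_compute_psnr()
    test_compute_ssim()
    test_compute_lpips()
    test_compute_metrics()
    print("All tests passed.")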