Spaces:
Running
Running
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the BSD-style license found in the | |
# LICENSE file in the root directory of this source tree. | |
import contextlib | |
import dataclasses | |
import itertools | |
import math | |
import os | |
import unittest | |
import lpips | |
import numpy as np | |
import torch | |
from pytorch3d.implicitron.dataset.frame_data import FrameData | |
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset | |
from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import eval_batch | |
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase | |
from pytorch3d.implicitron.models.generic_model import GenericModel # noqa | |
from pytorch3d.implicitron.models.model_dbir import ModelDBIR # noqa | |
from pytorch3d.implicitron.tools.config import expand_args_fields, registry | |
from pytorch3d.implicitron.tools.metric_utils import calc_psnr, eval_depth | |
from pytorch3d.implicitron.tools.utils import dataclass_to_cuda_ | |
from .common_resources import get_skateboard_data, provide_lpips_vgg | |
class TestEvaluation(unittest.TestCase): | |
def setUp(self): | |
# initialize evaluation dataset/dataloader | |
torch.manual_seed(42) | |
stack = contextlib.ExitStack() | |
dataset_root, path_manager = stack.enter_context(get_skateboard_data()) | |
self.addCleanup(stack.close) | |
category = "skateboard" | |
frame_file = os.path.join(dataset_root, category, "frame_annotations.jgz") | |
sequence_file = os.path.join(dataset_root, category, "sequence_annotations.jgz") | |
self.image_size = 64 | |
expand_args_fields(JsonIndexDataset) | |
self.dataset = JsonIndexDataset( | |
frame_annotations_file=frame_file, | |
sequence_annotations_file=sequence_file, | |
dataset_root=dataset_root, | |
image_height=self.image_size, | |
image_width=self.image_size, | |
box_crop=True, | |
remove_empty_masks=False, | |
path_manager=path_manager, | |
) | |
self.bg_color = (0.0, 0.0, 0.0) | |
# init the lpips model for eval | |
provide_lpips_vgg() | |
self.lpips_model = lpips.LPIPS(net="vgg").cuda() | |
def test_eval_depth(self): | |
""" | |
Check that eval_depth correctly masks errors and that, for get_best_scale=True, | |
the error with scaled prediction equals the error without scaling the | |
predicted depth. Finally, test that the error values are as expected | |
for prediction and gt differing by a constant offset. | |
""" | |
gt = (torch.randn(10, 1, 300, 400, device="cuda") * 5.0).clamp(0.0) | |
mask = (torch.rand_like(gt) > 0.5).type_as(gt) | |
for diff in 10 ** torch.linspace(-5, 0, 6): | |
for crop in (0, 5): | |
pred = gt + (torch.rand_like(gt) - 0.5) * 2 * diff | |
# scaled prediction test | |
mse_depth, abs_depth = eval_depth( | |
pred, | |
gt, | |
crop=crop, | |
mask=mask, | |
get_best_scale=True, | |
) | |
mse_depth_scale, abs_depth_scale = eval_depth( | |
pred * 10.0, | |
gt, | |
crop=crop, | |
mask=mask, | |
get_best_scale=True, | |
) | |
self.assertAlmostEqual( | |
float(mse_depth.sum()), float(mse_depth_scale.sum()), delta=1e-4 | |
) | |
self.assertAlmostEqual( | |
float(abs_depth.sum()), float(abs_depth_scale.sum()), delta=1e-4 | |
) | |
# error masking test | |
pred_masked_err = gt + (torch.rand_like(gt) + diff) * (1 - mask) | |
mse_depth_masked, abs_depth_masked = eval_depth( | |
pred_masked_err, | |
gt, | |
crop=crop, | |
mask=mask, | |
get_best_scale=True, | |
) | |
self.assertAlmostEqual( | |
float(mse_depth_masked.sum()), float(0.0), delta=1e-4 | |
) | |
self.assertAlmostEqual( | |
float(abs_depth_masked.sum()), float(0.0), delta=1e-4 | |
) | |
mse_depth_unmasked, abs_depth_unmasked = eval_depth( | |
pred_masked_err, | |
gt, | |
crop=crop, | |
mask=1 - mask, | |
get_best_scale=True, | |
) | |
self.assertGreater( | |
float(mse_depth_unmasked.sum()), | |
float(diff**2), | |
) | |
self.assertGreater( | |
float(abs_depth_unmasked.sum()), | |
float(diff), | |
) | |
# tests with constant error | |
pred_fix_diff = gt + diff * mask | |
for _mask_gt in (mask, None): | |
mse_depth_fix_diff, abs_depth_fix_diff = eval_depth( | |
pred_fix_diff, | |
gt, | |
crop=crop, | |
mask=_mask_gt, | |
get_best_scale=False, | |
) | |
if _mask_gt is not None: | |
expected_err_abs = diff | |
expected_err_mse = diff**2 | |
else: | |
err_mask = (gt > 0.0).float() * mask | |
if crop > 0: | |
err_mask = err_mask[:, :, crop:-crop, crop:-crop] | |
gt_cropped = gt[:, :, crop:-crop, crop:-crop] | |
else: | |
gt_cropped = gt | |
gt_mass = (gt_cropped > 0.0).float().sum(dim=(1, 2, 3)) | |
expected_err_abs = ( | |
diff * err_mask.sum(dim=(1, 2, 3)) / (gt_mass) | |
) | |
expected_err_mse = diff * expected_err_abs | |
self.assertTrue( | |
torch.allclose( | |
abs_depth_fix_diff, | |
expected_err_abs * torch.ones_like(abs_depth_fix_diff), | |
atol=1e-4, | |
) | |
) | |
self.assertTrue( | |
torch.allclose( | |
mse_depth_fix_diff, | |
expected_err_mse * torch.ones_like(mse_depth_fix_diff), | |
atol=1e-4, | |
) | |
) | |
def test_psnr(self): | |
""" | |
Compare against opencv and check that the psnr is above | |
the minimum possible value. | |
""" | |
import cv2 | |
im1 = torch.rand(100, 3, 256, 256).cuda() | |
im1_uint8 = (im1 * 255).to(torch.uint8) | |
im1_rounded = im1_uint8.float() / 255 | |
for max_diff in 10 ** torch.linspace(-5, 0, 6): | |
im2 = im1 + (torch.rand_like(im1) - 0.5) * 2 * max_diff | |
im2 = im2.clamp(0.0, 1.0) | |
im2_uint8 = (im2 * 255).to(torch.uint8) | |
im2_rounded = im2_uint8.float() / 255 | |
# check that our psnr matches the output of opencv | |
psnr = calc_psnr(im1_rounded, im2_rounded) | |
# some versions of cv2 can only take uint8 input | |
psnr_cv2 = cv2.PSNR( | |
im1_uint8.cpu().numpy(), | |
im2_uint8.cpu().numpy(), | |
) | |
self.assertAlmostEqual(float(psnr), float(psnr_cv2), delta=1e-4) | |
# check that all PSNRs are bigger than the minimum possible PSNR | |
max_mse = max_diff**2 | |
min_psnr = 10 * math.log10(1.0 / max_mse) | |
for _im1, _im2 in zip(im1, im2): | |
_psnr = calc_psnr(_im1, _im2) | |
self.assertGreaterEqual(float(_psnr) + 1e-6, min_psnr) | |
def _one_sequence_test( | |
self, | |
seq_dataset, | |
model, | |
batch_indices, | |
check_metrics=False, | |
): | |
loader = torch.utils.data.DataLoader( | |
seq_dataset, | |
shuffle=False, | |
batch_sampler=batch_indices, | |
collate_fn=FrameData.collate, | |
) | |
for frame_data in loader: | |
self.assertIsNone(frame_data.frame_type) | |
self.assertIsNotNone(frame_data.image_rgb) | |
# override the frame_type | |
frame_data.frame_type = [ | |
"train_unseen", | |
*(["train_known"] * (len(frame_data.image_rgb) - 1)), | |
] | |
frame_data = dataclass_to_cuda_(frame_data) | |
preds = model(**dataclasses.asdict(frame_data)) | |
eval_result = eval_batch( | |
frame_data, | |
preds["implicitron_render"], | |
bg_color=self.bg_color, | |
lpips_model=self.lpips_model, | |
) | |
if check_metrics: | |
self._check_metrics( | |
frame_data, preds["implicitron_render"], eval_result | |
) | |
def _check_metrics(self, frame_data, implicitron_render, eval_result): | |
# Make a terribly bad NVS prediction and check that this is worse | |
# than the DBIR prediction. | |
implicitron_render_bad = implicitron_render.clone() | |
implicitron_render_bad.depth_render += ( | |
torch.randn_like(implicitron_render_bad.depth_render) * 100.0 | |
) | |
implicitron_render_bad.image_render += ( | |
torch.randn_like(implicitron_render_bad.image_render) * 100.0 | |
) | |
implicitron_render_bad.mask_render = ( | |
torch.randn_like(implicitron_render_bad.mask_render) > 0.0 | |
).float() | |
eval_result_bad = eval_batch( | |
frame_data, | |
implicitron_render_bad, | |
bg_color=self.bg_color, | |
lpips_model=self.lpips_model, | |
) | |
lower_better = { | |
"psnr_masked": False, | |
"psnr_fg": False, | |
"psnr_full_image": False, | |
"depth_abs_fg": True, | |
"iou": False, | |
"rgb_l1_masked": True, | |
"rgb_l1_fg": True, | |
"lpips_masked": True, | |
"lpips_full_image": True, | |
} | |
for metric in lower_better: | |
m_better = eval_result[metric] | |
m_worse = eval_result_bad[metric] | |
if np.isnan(m_better) or np.isnan(m_worse): | |
continue # metric is missing, i.e. NaN | |
_assert = ( | |
self.assertLessEqual | |
if lower_better[metric] | |
else self.assertGreaterEqual | |
) | |
_assert(m_better, m_worse) | |
def _get_random_batch_indices( | |
self, seq_dataset, n_batches=2, min_batch_size=5, max_batch_size=10 | |
): | |
batch_indices = [] | |
for _ in range(n_batches): | |
batch_size = torch.randint( | |
low=min_batch_size, high=max_batch_size, size=(1,) | |
) | |
batch_indices.append(torch.randperm(len(seq_dataset))[:batch_size]) | |
return batch_indices | |
def test_full_eval(self, n_sequences=5): | |
"""Test evaluation.""" | |
# caching batch indices first to preserve RNG state | |
seq_datasets = {} | |
batch_indices = {} | |
for seq in itertools.islice(self.dataset.sequence_names(), n_sequences): | |
idx = list(self.dataset.sequence_indices_in_order(seq)) | |
seq_dataset = torch.utils.data.Subset(self.dataset, idx) | |
seq_datasets[seq] = seq_dataset | |
batch_indices[seq] = self._get_random_batch_indices(seq_dataset) | |
for model_class_type in ["ModelDBIR", "GenericModel"]: | |
ModelClass = registry.get(ImplicitronModelBase, model_class_type) | |
expand_args_fields(ModelClass) | |
model = ModelClass( | |
render_image_width=self.image_size, | |
render_image_height=self.image_size, | |
bg_color=self.bg_color, | |
) | |
model.eval() | |
model.cuda() | |
for seq in itertools.islice(self.dataset.sequence_names(), n_sequences): | |
self._one_sequence_test( | |
seq_datasets[seq], | |
model, | |
batch_indices[seq], | |
check_metrics=(model_class_type == "ModelDBIR"), | |
) | |