"""Evaluation for NYU depth. |
|
|
|
At evaluation time the ground truth is cropped and clipped. Values outside of |
|
the test crop or clipping range are not included in eval calculations. |
|
|
|
In this evaluator, it is assume that the groud truth is already cropped, so the |
|
entire image is evaluated. However, the evaluator does perform the clipping. |
|
|
|
Reference implementations: |
|
https://github.com/zhyever/Monocular-Depth-Estimation-Toolbox/blo(internal link)a0f341244260ff61541191a613dd74bc/depth/datasets/nyu.py |
|
https://github.com/vinvino02/GLPDepth/blob/7f3c78df4ecd6e7c79fd0c4b73c95d61f4aa2121/code/utils/metrics.py |
|
https://github.com/shariqfarooq123/AdaBins/blob/2fb686a66a304f0a719bc53d77412460af97fd61/evaluate.py |
|
""" |

import functools

import big_vision.evaluators.proj.uvim.common as common
import big_vision.pp.builder as pp_builder
import jax
import jax.numpy as jnp
import numpy as np
import tensorflow as tf

# Size of the NYU depth eval crop (the commonly used 426x560 "Eigen" crop).
EVAL_CROP_H = 426
EVAL_CROP_W = 560


class Evaluator:
  """Evaluator for NYU depth."""

  def __init__(self,
               predict_fn,
               pp_fn,
               batch_size,
               dataset,
               split,
               min_depth=1e-3,
               max_depth=10,
               dataset_dir=None,
               predict_kwargs=None):
    self.min_depth = min_depth
    self.max_depth = max_depth

    def predict(params, batch):
      pred = predict_fn(params, batch, **(predict_kwargs or {}))
      # Gather predictions, ground truth and the validity mask from all
      # devices so that the lead host can compute metrics over everything.
      return jax.lax.all_gather({
          "mask": batch["mask"],
          "gt": jnp.squeeze(batch["ground_truth"], axis=-1),
          "y": pred["depth"],
      }, axis_name="data", axis=0)

    self.predict_fn = jax.pmap(predict, axis_name="data")

    def preprocess(example):
      # Real examples get mask=1; run() later keeps only entries with mask == 1.
      return {
          "mask": tf.constant(1),
          **pp_builder.get_preprocess_fn(pp_fn)(example),
      }

    self.process_batch_size = batch_size // jax.process_count()
    self.data = common.get_jax_process_dataset(
        dataset=dataset,
        dataset_dir=dataset_dir,
        split=split,
        global_batch_size=batch_size,
        pp_fn=preprocess)

  def run(self, params):
    """Run eval."""
    # The ground truth is assumed to be pre-cropped to the eval crop, so all
    # pixels are eligible; only the depth-range clipping is applied below.
    eval_mask = np.ones((EVAL_CROP_H, EVAL_CROP_W), dtype=np.bool_)
    rmses = []
    abs_res = []
    abs_logs = []
    d1s = []
    d2s = []
    d3s = []
    for batch in self.data.as_numpy_iterator():
      out = self.predict_fn(params, batch)

      # Results are all-gathered, so only the first process computes metrics.
      if jax.process_index():
        continue

      # Take the gathered copy from the first local device and keep only
      # valid (mask == 1) examples.
      out = jax.tree_map(lambda x: jax.device_get(x[0]), out)
      out = jax.tree_map(lambda x: x[out["mask"] == 1], out)

      for gt, pred in zip(out["gt"], out["y"]):
        pred = _resize_nearest(pred, (EVAL_CROP_H, EVAL_CROP_W))
        # Clipping: only pixels within (min_depth, max_depth) enter the metrics.
        valid_mask = np.logical_and(gt > self.min_depth, gt < self.max_depth)
        valid_mask = np.logical_and(valid_mask, eval_mask)

        rmses.append(_compute_rmse(gt[valid_mask], pred[valid_mask]))
        abs_res.append(_compute_abs_re(gt[valid_mask], pred[valid_mask]))
        abs_logs.append(_compute_abs_log(gt[valid_mask], pred[valid_mask]))
        d1s.append(_compute_delta(gt[valid_mask], pred[valid_mask], order=1))
        d2s.append(_compute_delta(gt[valid_mask], pred[valid_mask], order=2))
        d3s.append(_compute_delta(gt[valid_mask], pred[valid_mask], order=3))

    if jax.process_index():
      return

    yield "RMSE", np.mean(rmses)
    yield "abs_RE", np.mean(abs_res)
    yield "log10", np.mean(abs_logs)
    yield "delta1", np.mean(d1s)
    yield "delta2", np.mean(d2s)
    yield "delta3", np.mean(d3s)
@functools.partial(jax.jit, static_argnums=(1,), backend="cpu")
def _resize_nearest(image, shape):
  return jax.image.resize(image, shape, "nearest")


def _compute_rmse(gt, pred):
  diff = gt - pred
  return np.sqrt(np.mean(np.power(diff, 2)))


def _compute_abs_re(gt, pred):
  diff = np.abs(gt - pred)
  return np.mean(diff / gt)


def _compute_abs_log(gt, pred):
  diff = np.abs(np.log10(gt) - np.log10(pred))
  return np.mean(diff)


def _compute_delta(gt, pred, order):
  rel_diff = np.maximum(gt / pred, pred / gt)
  return np.sum(rel_diff < 1.25**order) / rel_diff.size
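
# A tiny worked example of the metric helpers, on made-up values (purely
# illustrative, not part of the evaluator):
#
#   gt = np.array([1.0, 2.0, 4.0])
#   pred = np.array([1.1, 2.0, 4.8])
#   _compute_rmse(gt, pred)            # sqrt((0.01 + 0.0 + 0.64) / 3) ~= 0.466
#   _compute_abs_re(gt, pred)          # mean([0.1, 0.0, 0.2]) = 0.1
#   _compute_delta(gt, pred, order=1)  # all ratios < 1.25, so 1.0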