from collections import defaultdict from functools import partial import torch import torch.nn.functional as F import torchvision.transforms.v2.functional as TF from PIL import Image from unik3d.utils.chamfer_distance import ChamferDistance from unik3d.utils.constants import DEPTH_BINS chamfer_cls = ChamferDistance() def kl_div(gt, pred, eps: float = 1e-6): depth_bins = DEPTH_BINS.to(gt.device) gt, pred = torch.bucketize( gt, boundaries=depth_bins, out_int32=True ), torch.bucketize(pred, boundaries=depth_bins, out_int32=True) gt = torch.bincount(gt, minlength=len(depth_bins) + 1) pred = torch.bincount(pred, minlength=len(depth_bins) + 1) gt = gt / gt.sum() pred = pred / pred.sum() return torch.sum(gt * (torch.log(gt + eps) - torch.log(pred + eps))) def chamfer_dist(tensor1, tensor2): x_lengths = torch.tensor((tensor1.shape[1],), device=tensor1.device) y_lengths = torch.tensor((tensor2.shape[1],), device=tensor2.device) dist1, dist2, idx1, idx2 = chamfer_cls( tensor1, tensor2, x_lengths=x_lengths, y_lengths=y_lengths ) return (torch.sqrt(dist1) + torch.sqrt(dist2)) / 2 def auc(tensor1, tensor2, thresholds): x_lengths = torch.tensor((tensor1.shape[1],), device=tensor1.device) y_lengths = torch.tensor((tensor2.shape[1],), device=tensor2.device) dist1, dist2, idx1, idx2 = chamfer_cls( tensor1, tensor2, x_lengths=x_lengths, y_lengths=y_lengths ) # compute precision recall precisions = [(dist1 < threshold).sum() / dist1.numel() for threshold in thresholds] recalls = [(dist2 < threshold).sum() / dist2.numel() for threshold in thresholds] auc_value = torch.trapz( torch.tensor(precisions, device=tensor1.device), torch.tensor(recalls, device=tensor1.device), ) return auc_value def delta(tensor1, tensor2, exponent): inlier = torch.maximum((tensor1 / tensor2), (tensor2 / tensor1)) return (inlier < 1.25**exponent).to(torch.float32).mean() def rho(tensor1, tensor2): min_deg = 0.5 tensor1_norm = tensor1 / torch.norm(tensor1, dim=-1, p=2, keepdim=True).clip( min=1e-6 ) tensor2_norm = tensor2 / torch.norm(tensor2, dim=-1, p=2, keepdim=True).clip( min=1e-6 ) max_polar_angle = torch.arccos(tensor1_norm[..., 2]).max() * 180.0 / torch.pi if max_polar_angle < 100.0: threshold = 15.0 elif max_polar_angle < 190.0: threshold = 20.0 else: threshold = 30.0 acos_clip = 1 - 1e-6 # inner prod of norm vector -> cosine angular_error = ( torch.arccos( (tensor1_norm * tensor2_norm) .sum(dim=-1) .clip(min=-acos_clip, max=acos_clip) ) * 180.0 / torch.pi ) thresholds = torch.linspace(min_deg, threshold, steps=100, device=tensor1.device) y_values = [ (angular_error.abs() <= th).to(torch.float32).mean() for th in thresholds ] auc_value = torch.trapz( torch.tensor(y_values, device=tensor1.device), thresholds ) / (threshold - min_deg) return auc_value def tau(tensor1, tensor2, perc): inlier = torch.maximum((tensor1 / tensor2), (tensor2 / tensor1)) return (inlier < (1.0 + perc)).to(torch.float32).mean() @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) def ssi(tensor1, tensor2, qtl=0.05): stability_mat = 1e-9 * torch.eye(2, device=tensor1.device) error = (tensor1 - tensor2).abs() mask = error < torch.quantile(error, 1 - qtl) tensor1_mask = tensor1.to(torch.float32)[mask] tensor2_mask = tensor2.to(torch.float32)[mask] stability_mat = 1e-4 * torch.eye(2, device=tensor1.device) tensor2_one = torch.stack([tensor2_mask, torch.ones_like(tensor2_mask)], dim=1) A = torch.matmul(tensor2_one.T, tensor2_one) + stability_mat det_A = A[0, 0] * A[1, 1] - A[0, 1] * A[1, 0] A_inv = (1.0 / det_A) * torch.tensor( [[A[1, 1], -A[0, 1]], [-A[1, 0], A[0, 0]]], device=tensor1.device ) b = tensor2_one.T @ tensor1_mask.unsqueeze(1) scale_shift = A_inv @ b scale, shift = scale_shift.squeeze().chunk(2, dim=0) return tensor2 * scale + shift def si(tensor1, tensor2): return tensor2 * torch.median(tensor1) / torch.median(tensor2) def arel(tensor1, tensor2): tensor2 = tensor2 * torch.median(tensor1) / torch.median(tensor2) return (torch.abs(tensor1 - tensor2) / tensor1).mean() def d_auc(tensor1, tensor2): exponents = torch.linspace(0.01, 5.0, steps=100, device=tensor1.device) deltas = [delta(tensor1, tensor2, exponent) for exponent in exponents] return torch.trapz(torch.tensor(deltas, device=tensor1.device), exponents) / 5.0 def f1_score(tensor1, tensor2, thresholds): x_lengths = torch.tensor((tensor1.shape[1],), device=tensor1.device) y_lengths = torch.tensor((tensor2.shape[1],), device=tensor2.device) dist1, dist2, idx1, idx2 = chamfer_cls( tensor1, tensor2, x_lengths=x_lengths, y_lengths=y_lengths ) # compute precision recall precisions = [(dist1 < threshold).sum() / dist1.numel() for threshold in thresholds] recalls = [(dist2 < threshold).sum() / dist2.numel() for threshold in thresholds] precisions = torch.tensor(precisions, device=tensor1.device) recalls = torch.tensor(recalls, device=tensor1.device) f1_thresholds = 2 * precisions * recalls / (precisions + recalls) f1_thresholds = torch.where( torch.isnan(f1_thresholds), torch.zeros_like(f1_thresholds), f1_thresholds ) f1_value = torch.trapz(f1_thresholds) / len(thresholds) return f1_value def f1_score_si(tensor1, tensor2, thresholds): tensor2 = ( tensor2 * torch.median(tensor1.norm(dim=-1)) / torch.median(tensor2.norm(dim=-1)) ) f1_value = f1_score(tensor1, tensor2, thresholds) return f1_value DICT_METRICS = { "d1": partial(delta, exponent=1.0), "d2": partial(delta, exponent=2.0), "d3": partial(delta, exponent=3.0), "rmse": lambda gt, pred: torch.sqrt(((gt - pred) ** 2).mean()), "rmselog": lambda gt, pred: torch.sqrt( ((torch.log(gt) - torch.log(pred)) ** 2).mean() ), "arel": lambda gt, pred: (torch.abs(gt - pred) / gt).mean(), "sqrel": lambda gt, pred: (((gt - pred) ** 2) / gt).mean(), "log10": lambda gt, pred: torch.abs(torch.log10(pred) - torch.log10(gt)).mean(), "silog": lambda gt, pred: 100 * torch.std(torch.log(pred) - torch.log(gt)).mean(), "medianlog": lambda gt, pred: 100 * (torch.log(pred) - torch.log(gt)).median().abs(), "d_auc": d_auc, "tau": partial(tau, perc=0.03), } DICT_METRICS_3D = { "MSE_3d": lambda gt, pred, thresholds: torch.norm(gt - pred, dim=0, p=2), "arel_3d": lambda gt, pred, thresholds: torch.norm(gt - pred, dim=0, p=2) / torch.norm(gt, dim=0, p=2), "tau_3d": lambda gt, pred, thresholds: ( (torch.norm(pred, dim=0, p=2) / torch.norm(gt, dim=0, p=2)).log().abs().exp() < 1.25 ) .float() .mean(), "chamfer": lambda gt, pred, thresholds: chamfer_dist( gt.unsqueeze(0).permute(0, 2, 1), pred.unsqueeze(0).permute(0, 2, 1) ), "F1": lambda gt, pred, thresholds: f1_score( gt.unsqueeze(0).permute(0, 2, 1), pred.unsqueeze(0).permute(0, 2, 1), thresholds=thresholds, ), "F1_si": lambda gt, pred, thresholds: f1_score_si( gt.unsqueeze(0).permute(0, 2, 1), pred.unsqueeze(0).permute(0, 2, 1), thresholds=thresholds, ), "rays": lambda gt, pred, thresholds: rho( gt.unsqueeze(0).permute(0, 2, 1), pred.unsqueeze(0).permute(0, 2, 1) ), } DICT_METRICS_FLOW = { "epe": lambda gt, pred: torch.sqrt(torch.square(gt - pred).sum(dim=0)), "epe1": lambda gt, pred: torch.sqrt(torch.square(gt - pred).sum(dim=0)) < 1, "epe3": lambda gt, pred: torch.sqrt(torch.square(gt - pred).sum(dim=0)) < 3, "epe5": lambda gt, pred: torch.sqrt(torch.square(gt - pred).sum(dim=0)) < 5, } DICT_METRICS_D = { "a1": lambda gt, pred: (torch.maximum((gt / pred), (pred / gt)) > 1.25**1.0).to( torch.float32 ), "abs_rel": lambda gt, pred: (torch.abs(gt - pred) / gt), } def eval_depth( gts: torch.Tensor, preds: torch.Tensor, masks: torch.Tensor, max_depth=None ): summary_metrics = defaultdict(list) # preds = F.interpolate(preds, gts.shape[-2:], mode="bilinear") for i, (gt, pred, mask) in enumerate(zip(gts, preds, masks)): if max_depth is not None: mask = mask & (gt <= max_depth) for name, fn in DICT_METRICS.items(): if name in ["tau", "d1", "arel"]: for rescale_fn in ["ssi", "si"]: summary_metrics[f"{name}_{rescale_fn}"].append( fn(gt[mask], eval(rescale_fn)(gt[mask], pred[mask])) ) summary_metrics[name].append(fn(gt[mask], pred[mask]).mean()) return {name: torch.stack(vals, dim=0) for name, vals in summary_metrics.items()} def eval_3d( gts: torch.Tensor, preds: torch.Tensor, masks: torch.Tensor, thresholds=None ): summary_metrics = defaultdict(list) MAX_PIXELS = 75_000 # 300_000 ratio = min(1.0, (MAX_PIXELS / masks[0].sum()) ** 0.5) h_max, w_max = int(gts.shape[-2] * ratio), int(gts.shape[-1] * ratio) gts = F.interpolate(gts, size=(h_max, w_max), mode="nearest-exact") preds = F.interpolate(preds, size=(h_max, w_max), mode="nearest-exact") masks = F.interpolate( masks.float(), size=(h_max, w_max), mode="nearest-exact" ).bool() for i, (gt, pred, mask) in enumerate(zip(gts, preds, masks)): if not torch.any(mask): continue for name, fn in DICT_METRICS_3D.items(): summary_metrics[name].append( fn(gt[:, mask.squeeze()], pred[:, mask.squeeze()], thresholds).mean() ) return {name: torch.stack(vals, dim=0) for name, vals in summary_metrics.items()} def compute_aucs(gt, pred, mask, uncertainties, steps=50, metrics=["abs_rel"]): dict_ = {} x_axis = torch.linspace(0, 1, steps=steps + 1, device=gt.device) quantiles = torch.linspace(0, 1 - 1 / steps, steps=steps, device=gt.device) zer = torch.tensor(0.0, device=gt.device) # revert order (high uncertainty first) uncertainties = -uncertainties[mask] gt = gt[mask] pred = pred[mask] true_uncert = {metric: -DICT_METRICS_D[metric](gt, pred) for metric in metrics} # get percentiles for sampling and corresponding subsets thresholds = torch.quantile(uncertainties, quantiles) subs = [(uncertainties >= t) for t in thresholds] # compute sparsification curves for each metric (add 0 for final sampling) for metric in metrics: opt_thresholds = torch.quantile(true_uncert[metric], quantiles) opt_subs = [(true_uncert[metric] >= t) for t in opt_thresholds] sparse_curve = torch.stack( [DICT_METRICS[metric](gt[sub], pred[sub]) for sub in subs] + [zer], dim=0 ) opt_curve = torch.stack( [DICT_METRICS[metric](gt[sub], pred[sub]) for sub in opt_subs] + [zer], dim=0, ) rnd_curve = DICT_METRICS[metric](gt, pred) dict_[f"AUSE_{metric}"] = torch.trapz(sparse_curve - opt_curve, x=x_axis) dict_[f"AURG_{metric}"] = rnd_curve - torch.trapz(sparse_curve, x=x_axis) return dict_ def eval_depth_uncertainties( gts: torch.Tensor, preds: torch.Tensor, uncertainties: torch.Tensor, masks: torch.Tensor, max_depth=None, ): summary_metrics = defaultdict(list) preds = F.interpolate(preds, gts.shape[-2:], mode="bilinear") for i, (gt, pred, mask, uncertainty) in enumerate( zip(gts, preds, masks, uncertainties) ): if max_depth is not None: mask = torch.logical_and(mask, gt < max_depth) for name, fn in DICT_METRICS.items(): summary_metrics[name].append(fn(gt[mask], pred[mask])) for name, val in compute_aucs(gt, pred, mask, uncertainty).items(): summary_metrics[name].append(val) return {name: torch.stack(vals, dim=0) for name, vals in summary_metrics.items()} def lazy_eval_depth( gt_fns, pred_fns, min_depth=1e-2, max_depth=None, depth_scale=256.0 ): summary_metrics = defaultdict(list) for i, (gt_fn, pred_fn) in enumerate(zip(gt_fns, pred_fns)): gt = TF.pil_to_tensor(Image.open(gt_fn)).to(torch.float32) / depth_scale pred = TF.pil_to_tensor(Image.open(pred_fn)).to(torch.float32) / depth_scale mask = gt > min_depth if max_depth is not None: mask_2 = gt < max_depth mask = torch.logical_and(mask, mask_2) for name, fn in DICT_METRICS.items(): summary_metrics[name].append(fn(gt[mask], pred[mask])) return {name: torch.mean(vals).item() for name, vals in summary_metrics.items()}