# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
# More information about the method can be found at https://marigoldmonodepth.github.io
# --------------------------------------------------------------------------

from functools import partial
from typing import Optional, Tuple

import numpy as np
import torch

from .image_util import get_tv_resample_method, resize_max_res


def inter_distances(tensors: torch.Tensor) -> torch.Tensor:
    """
    Calculate the element-wise difference between every unordered pair of depth maps.

    Args:
        tensors (`torch.Tensor`):
            Stack of maps indexed along the first dimension.

    Returns:
        A tensor whose first dimension enumerates the pairs `(i, j)` with `i < j` in lexicographic order and
        whose remaining dimensions match `tensors.shape[1:]`. Each entry equals `tensors[i] - tensors[j]`. When
        fewer than two maps are given, the first dimension is empty.
    """
    num_maps = tensors.shape[0]
    # Row-major upper-triangular indices enumerate pairs in the same order as `torch.combinations`.
    idx_i, idx_j = torch.triu_indices(num_maps, num_maps, offset=1, device=tensors.device)
    return tensors[idx_i] - tensors[idx_j]


def ensemble_depth(
    depth: torch.Tensor,
    scale_invariant: bool = True,
    shift_invariant: bool = True,
    output_uncertainty: bool = False,
    reduction: str = "median",
    regularizer_strength: float = 0.02,
    max_iter: int = 2,
    tol: float = 1e-3,
    max_res: int = 1024,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Ensembles depth maps represented by the `depth` tensor with expected shape `(B, 1, H, W)`, where B is the
    number of ensemble members for a given prediction of size `(H x W)`. Even though the function is designed for
    depth maps, it can also be used with disparity maps as long as the input tensor values are non-negative. The
    alignment happens when the predictions have one or more degrees of freedom, that is when they are either
    affine-invariant (`scale_invariant=True` and `shift_invariant=True`), or just scale-invariant (only
    `scale_invariant=True`). For absolute predictions (`scale_invariant=False` and `shift_invariant=False`)
    alignment and the final range normalization are skipped and only ensembling is performed.

    Args:
        depth (`torch.Tensor`):
            Input ensemble depth maps.
        scale_invariant (`bool`, *optional*, defaults to `True`):
            Whether to treat predictions as scale-invariant.
        shift_invariant (`bool`, *optional*, defaults to `True`):
            Whether to treat predictions as shift-invariant.
        output_uncertainty (`bool`, *optional*, defaults to `False`):
            Whether to output uncertainty map.
        reduction (`str`, *optional*, defaults to `"median"`):
            Reduction method used to ensemble aligned predictions. The accepted values are: `"mean"` and
            `"median"`.
        regularizer_strength (`float`, *optional*, defaults to `0.02`):
            Strength of the regularizer that pulls the aligned predictions to the unit range from 0 to 1.
        max_iter (`int`, *optional*, defaults to `2`):
            Maximum number of the alignment solver steps. Refer to `scipy.optimize.minimize` function, `options`
            argument.
        tol (`float`, *optional*, defaults to `1e-3`):
            Alignment solver tolerance. The solver stops when the tolerance is reached.
        max_res (`int`, *optional*, defaults to `1024`):
            Maximum edge length at which the alignment parameters are estimated; larger inputs are downsampled
            for the solver only. `None` disables downsampling, so alignment runs at the input resolution.

    Returns:
        A tensor of aligned and ensembled depth maps and optionally a tensor of uncertainties of the same shape:
        `(1, 1, H, W)`. The second element is `None` when `output_uncertainty=False`.

    Raises:
        ValueError: If `depth` is not a 4D tensor with a singleton channel dimension, if `reduction` is not
            recognized, or if pure shift-invariant ensembling is requested.
    """
    if depth.dim() != 4 or depth.shape[1] != 1:
        raise ValueError(f"Expecting 4D tensor of shape [B,1,H,W]; got {depth.shape}.")
    if reduction not in ("mean", "median"):
        raise ValueError(f"Unrecognized reduction method: {reduction}.")
    if not scale_invariant and shift_invariant:
        raise ValueError("Pure shift-invariant ensembling is not supported.")

    def init_param(depth: torch.Tensor) -> np.ndarray:
        """Build the initial solver parameters that map each member's value range to roughly [0, 1]."""
        flat = depth.reshape(ensemble_size, -1)
        init_min = flat.min(dim=1).values
        init_max = flat.max(dim=1).values

        if scale_invariant and shift_invariant:
            # Clamp guards against division by zero for constant maps.
            init_s = 1.0 / (init_max - init_min).clamp(min=1e-6)
            init_t = -init_s * init_min
            param = torch.cat((init_s, init_t)).cpu().numpy()
        elif scale_invariant:
            init_s = 1.0 / init_max.clamp(min=1e-6)
            param = init_s.cpu().numpy()
        else:
            raise ValueError("Unrecognized alignment.")

        return param

    def align(depth: torch.Tensor, param: np.ndarray) -> torch.Tensor:
        """Apply the per-member scale (and shift, if enabled) stored in `param` to `depth`."""
        if scale_invariant and shift_invariant:
            # First half of `param` holds the scales, second half the shifts.
            s, t = np.split(param, 2)
            s = torch.from_numpy(s).to(depth).view(ensemble_size, 1, 1, 1)
            t = torch.from_numpy(t).to(depth).view(ensemble_size, 1, 1, 1)
            out = depth * s + t
        elif scale_invariant:
            s = torch.from_numpy(param).to(depth).view(ensemble_size, 1, 1, 1)
            out = depth * s
        else:
            raise ValueError("Unrecognized alignment.")
        return out

    def ensemble(
        depth_aligned: torch.Tensor, return_uncertainty: bool = False
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Reduce the members along the first dimension; optionally return a per-pixel spread estimate."""
        uncertainty = None
        if reduction == "mean":
            prediction = torch.mean(depth_aligned, dim=0, keepdim=True)
            if return_uncertainty:
                uncertainty = torch.std(depth_aligned, dim=0, keepdim=True)
        elif reduction == "median":
            prediction = torch.median(depth_aligned, dim=0, keepdim=True).values
            if return_uncertainty:
                # Median absolute deviation from the median prediction.
                uncertainty = torch.median(
                    torch.abs(depth_aligned - prediction), dim=0, keepdim=True
                ).values
        else:
            raise ValueError(f"Unrecognized reduction method: {reduction}.")
        return prediction, uncertainty

    def cost_fn(param: np.ndarray, depth: torch.Tensor) -> float:
        """Sum of pairwise RMS differences between aligned members plus the unit-range regularizer."""
        cost = 0.0
        depth_aligned = align(depth, param)

        for i, j in torch.combinations(torch.arange(ensemble_size)):
            diff = depth_aligned[i] - depth_aligned[j]
            cost += (diff**2).mean().sqrt().item()

        if regularizer_strength > 0:
            prediction, _ = ensemble(depth_aligned, return_uncertainty=False)
            err_near = (0.0 - prediction.min()).abs().item()
            err_far = (1.0 - prediction.max()).abs().item()
            cost += (err_near + err_far) * regularizer_strength

        return cost

    def compute_param(depth: torch.Tensor) -> np.ndarray:
        """Estimate the alignment parameters with BFGS, on a downsampled copy when `max_res` applies."""
        # Import the submodule explicitly: a bare `import scipy` is not guaranteed to load `scipy.optimize`.
        import scipy.optimize

        depth_to_align = depth.to(torch.float32)
        if max_res is not None and max(depth_to_align.shape[2:]) > max_res:
            try:
                depth_to_align = resize_max_res(
                    depth_to_align, max_res, get_tv_resample_method("nearest-exact")
                )
            except Exception:
                # Best-effort fallback if the "nearest-exact" path fails; do not trap
                # KeyboardInterrupt / SystemExit.
                depth_to_align = resize_max_res(
                    depth_to_align, max_res, get_tv_resample_method("bilinear")
                )

        param = init_param(depth_to_align)

        res = scipy.optimize.minimize(
            partial(cost_fn, depth=depth_to_align),
            param,
            method="BFGS",
            tol=tol,
            options={"maxiter": max_iter, "disp": False},
        )

        return res.x

    requires_aligning = scale_invariant or shift_invariant
    ensemble_size = depth.shape[0]

    if requires_aligning:
        param = compute_param(depth)
        depth = align(depth, param)

    depth, uncertainty = ensemble(depth, return_uncertainty=output_uncertainty)

    if requires_aligning:
        # Rescale the ensembled prediction to the unit range. Pure shift-invariance is rejected above, so
        # reaching this point means the prediction is scale-invariant and possibly also shift-invariant.
        depth_max = depth.max()
        depth_min = depth.min() if shift_invariant else 0
        depth_range = (depth_max - depth_min).clamp(min=1e-6)
        depth = (depth - depth_min) / depth_range
        if output_uncertainty:
            uncertainty /= depth_range
    # Absolute predictions carry metric meaning, so they are returned ensembled but not normalized.

    return depth, uncertainty  # [1,1,H,W], [1,1,H,W]