# Copyright (C) 2024-present Naver Corporation. All rights reserved. # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). # # -------------------------------------------------------- # utilitary functions for DUSt3R # -------------------------------------------------------- import numpy as np import torch def todevice(batch, device, callback=None, non_blocking=False): ''' Transfer some variables to another device (i.e. GPU, CPU:torch, CPU:numpy). batch: list, tuple, dict of tensors or other things device: pytorch device or 'numpy' callback: function that would be called on every sub-elements. ''' if callback: batch = callback(batch) if isinstance(batch, dict): return {k: todevice(v, device) for k, v in batch.items()} if isinstance(batch, (tuple, list)): return type(batch)(todevice(x, device) for x in batch) x = batch if device == 'numpy': if isinstance(x, torch.Tensor): x = x.detach().cpu().numpy() elif x is not None: if isinstance(x, np.ndarray): x = torch.from_numpy(x) if torch.is_tensor(x): x = x.to(device, non_blocking=non_blocking) return x to_device = todevice # alias def to_numpy(x): return todevice(x, 'numpy') def to_cpu(x): return todevice(x, 'cpu') def to_cuda(x): return todevice(x, 'cuda') def collate_with_cat(whatever, lists=False): if isinstance(whatever, dict): return {k: collate_with_cat(vals, lists=lists) for k, vals in whatever.items()} elif isinstance(whatever, (tuple, list)): if len(whatever) == 0: return whatever elem = whatever[0] T = type(whatever) if elem is None: return None if isinstance(elem, (bool, float, int, str)): return whatever if isinstance(elem, tuple): return T(collate_with_cat(x, lists=lists) for x in zip(*whatever)) if isinstance(elem, dict): return {k: collate_with_cat([e[k] for e in whatever], lists=lists) for k in elem} if isinstance(elem, torch.Tensor): return listify(whatever) if lists else torch.cat(whatever) if isinstance(elem, np.ndarray): return listify(whatever) if lists else torch.cat([torch.from_numpy(x) for x in whatever]) # otherwise, we just chain lists return sum(whatever, T()) def listify(elems): return [x for e in elems for x in e]