""" Adapted from https://github.com/SongweiGe/TATS"""
# Copyright (c) Meta Platforms, Inc. All Rights Reserved

import warnings
import torch
import imageio

import math
import numpy as np

import sys
import pdb as pdb_original
# import SimpleITK as sitk
import logging

import imageio.core.util

logging.getLogger("imageio_ffmpeg").setLevel(logging.ERROR)


def _backend_is_available(backend):
    """Return True if ``backend.is_available()`` exists and reports True.

    ``backend`` may be None, or may lack an ``is_available`` attribute (e.g. a
    torch build that predates the XPU backend). Both cases count as "not
    available" instead of raising AttributeError.
    """
    is_available = getattr(backend, 'is_available', None)
    return bool(is_available is not None and is_available())


def get_single_device(cpu=True):
    """Pick a single torch device.

    Args:
        cpu: if True (the default), always return the CPU device.

    Returns:
        ``torch.device('cpu')`` when ``cpu`` is True. Otherwise, the first
        available accelerator, checked in the order CUDA, XPU, MPS. Returns
        None when ``cpu`` is False and none of them is available.
    """
    if cpu:
        return torch.device('cpu')
    if torch.cuda.is_available():
        return torch.device('cuda')
    # torch.xpu only exists in newer torch releases; probe it defensively.
    if _backend_is_available(getattr(torch, 'xpu', None)):
        return torch.device('xpu')
    # torch.backends.mps.is_available() is the documented MPS availability check.
    if _backend_is_available(getattr(torch.backends, 'mps', None)):
        return torch.device('mps')
    return None


class ForkedPdb(pdb_original.Pdb):
    """A Pdb subclass that may be used from a forked multiprocessing child.

    multiprocessing replaces a child process's ``sys.stdin`` with a handle on
    ``os.devnull``, so a plain ``pdb.set_trace()`` there cannot read commands.
    This subclass points ``sys.stdin`` at ``/dev/stdin`` (POSIX only) for the
    duration of each debugger prompt.

    Usage: ``ForkedPdb().set_trace()``
    """

    def interaction(self, *args, **kwargs):
        saved_stdin = sys.stdin
        # The context manager closes the /dev/stdin handle after every prompt,
        # instead of leaking one file descriptor per breakpoint hit.
        with open('/dev/stdin') as tty_stdin:
            sys.stdin = tty_stdin
            try:
                pdb_original.Pdb.interaction(self, *args, **kwargs)
            finally:
                # Restore the caller's stdin before the with-block closes ours.
                sys.stdin = saved_stdin


def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True):
    """Move dimension ``src_dim`` of ``x`` to position ``dest_dim``.

    The remaining dimensions keep their relative order. For example,
    ``shift_dim(x, 1, -1)`` maps (b, c, t, h, w) -> (b, t, h, w, c).

    Args:
        x: tensor to permute.
        src_dim: dimension to move; negative values count from the end.
        dest_dim: target position; negative values count from the end.
        make_contiguous: if True, call ``.contiguous()`` on the result.

    Returns:
        The permuted tensor.
    """
    n_dims = len(x.shape)
    if src_dim < 0:
        src_dim = n_dims + src_dim
    if dest_dim < 0:
        dest_dim = n_dims + dest_dim
    assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims

    # All other dims keep their order; src_dim is re-inserted at dest_dim.
    permutation = [d for d in range(n_dims) if d != src_dim]
    permutation.insert(dest_dim, src_dim)

    x = x.permute(permutation)
    if make_contiguous:
        x = x.contiguous()
    return x
def view_range(x, i, j, shape):
    """Reshape dims ``i`` (inclusive) to ``j`` (exclusive) of ``x`` into ``shape``.

    For example, if ``x.shape == (b, thw, c)``, then
    ``view_range(x, 1, 2, (t, h, w))`` returns ``x`` with shape
    (b, t, h, w, c).

    Args:
        x: tensor to reshape. The new shape is passed to ``x.view``.
        i: first dim to replace; negative values count from the end.
        j: end dim (exclusive). None means "through the last dim"; negative
            values count from the end.
        shape: iterable of ints that replaces ``x.shape[i:j]``.

    Returns:
        ``x.view(x.shape[:i] + tuple(shape) + x.shape[j:])``.
    """
    shape = tuple(shape)
    n_dims = len(x.shape)
    if i < 0:
        i = n_dims + i

    if j is None:
        j = n_dims
    elif j < 0:
        j = n_dims + j

    assert 0 <= i < j <= n_dims

    x_shape = x.shape
    target_shape = x_shape[:i] + shape + x_shape[j:]
    return x.view(target_shape)


def accuracy(output, target, topk=(1,)):
    """Compute the top-k accuracy (in percent) for each k in ``topk``.

    Args:
        output: score tensor; the top-k is taken along dim 1.
        target: tensor of ground-truth class indices, flattened and compared
            against each column of predictions (presumably one label per row
            of ``output`` -- confirm against callers).
        topk: iterable of k values.

    Returns:
        A list with one 1-element float tensor per k, holding
        ``100 * hits / target.size(0)``.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()  # row r holds every sample's (r+1)-th best prediction
        correct = pred.eq(target.reshape(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


def tensor_slice(x, begin, size):
    """Return the sub-array of ``x`` starting at ``begin`` with extent ``size``.

    Args:
        x: tensor or array to slice; only ``x.shape`` and indexing are used.
        begin: per-dimension start offsets (must be >= 0).
        size: per-dimension lengths. ``-1`` means "to the end of that dim".

    Returns:
        ``x[b0:b0+s0, b1:b1+s1, ...]``.
    """
    assert all(b >= 0 for b in begin)
    size = [dim_len - b if s == -1 else s
            for s, b, dim_len in zip(size, begin, x.shape)]
    assert all(s >= 0 for s in size)
    # Multidimensional indexing needs a *tuple* of slices. A list is rejected
    # by modern NumPy and relies on legacy sequence handling in torch.
    slices = tuple(slice(b, b + s) for b, s in zip(begin, size))
    return x[slices]


def adopt_weight(global_step, threshold=0, value=0.):
    """Return ``value`` while ``global_step < threshold``, otherwise 1."""
    return value if global_step < threshold else 1


def comp_getattr(args, attr_name, default=None):
    """Return ``args.<attr_name>`` if present, else ``default``."""
    return getattr(args, attr_name, default)


def visualize_tensors(t, name=None, nest=0):
    """Print a debug summary of a (possibly nested) dict/list of tensors.

    Types are classified by substring match on ``str(type(obj))``
    ('dict', 'list', 'Tensor'). Dict entries print their key together with a
    tensor shape, 'dict', or a list length. Other non-None dict values are not
    printed. Top-level non-container values are printed directly.

    Args:
        t: object to summarize.
        name: optional label printed, with the nesting level, at every level.
        nest: current recursion depth.

    Returns:
        An empty string.
    """
    if name is not None:
        print(name, "current nest: ", nest)
    print("type: ", type(t))

    type_str = str(type(t))
    if 'dict' in type_str:
        print(t.keys())
        for k in t.keys():
            v = t[k]
            v_type = str(type(v))
            if v is None:
                print(k, "None")
            elif 'Tensor' in v_type:
                print(k, v.shape)
            elif 'dict' in v_type:
                print(k, 'dict')
                visualize_tensors(v, name, nest + 1)
            elif 'list' in v_type:
                print(k, len(v))
                visualize_tensors(v, name, nest + 1)
    elif 'list' in type_str:
        print("list length: ", len(t))
        for item in t:
            visualize_tensors(item, name, nest + 1)
    elif 'Tensor' in type_str:
        print(t.shape)
    else:
        print(t)
    return ""