File size: 4,057 Bytes
9123ba9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
""" Adapted from https://github.com/SongweiGe/TATS"""
# Copyright (c) Meta Platforms, Inc. All Rights Reserved
import warnings
import torch
import imageio
import math
import numpy as np
import sys
import pdb as pdb_original
# import SimpleITK as sitk
import logging
import imageio.core.util
# Raise the "imageio_ffmpeg" logger threshold so that only records at ERROR
# level or above are emitted by it.
logging.getLogger("imageio_ffmpeg").setLevel(logging.ERROR)
class ForkedPdb(pdb_original.Pdb):
    """A Pdb subclass that may be used
    from a forked multiprocessing child
    """

    def interaction(self, *args, **kwargs):
        """Run the regular Pdb interaction with ``sys.stdin`` bound to /dev/stdin.

        ``sys.stdin`` is swapped for a freshly opened ``/dev/stdin`` for the
        duration of the interaction and restored afterwards, even if the
        interaction raises. All arguments are forwarded unchanged to
        ``pdb.Pdb.interaction``.
        """
        saved_stdin = sys.stdin
        try:
            # Open via a context manager so the handle is closed once the
            # interaction ends instead of leaking one descriptor per call.
            with open('/dev/stdin') as child_stdin:
                sys.stdin = child_stdin
                pdb_original.Pdb.interaction(self, *args, **kwargs)
        finally:
            sys.stdin = saved_stdin
# Moves dimension src_dim of x to position dest_dim, keeping the other dims in order,
# e.g. shift_dim(x, 1, -1) would map (b, c, t, h, w) -> (b, t, h, w, c)
def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True):
    """Move one dimension of ``x`` to a new position.

    The remaining dimensions keep their relative order, e.g.
    ``shift_dim(x, 1, -1)`` maps ``(b, c, t, h, w)`` to ``(b, t, h, w, c)``.

    Args:
        x: Tensor to permute.
        src_dim: Index of the dimension to move; negative values count
            from the end.
        dest_dim: Index the moved dimension ends up at; negative values
            count from the end.
        make_contiguous: When true, ``.contiguous()`` is called on the
            permuted result.

    Returns:
        The permuted tensor.

    Raises:
        AssertionError: If either index is out of range for ``x``.
    """
    rank = len(x.shape)
    src = src_dim + rank if src_dim < 0 else src_dim
    dest = dest_dim + rank if dest_dim < 0 else dest_dim
    assert 0 <= src < rank and 0 <= dest < rank

    # Every dim except the moved one, in original order; then drop the moved
    # dim into its destination slot.
    order = [d for d in range(rank) if d != src]
    order.insert(dest, src)

    shifted = x.permute(order)
    return shifted.contiguous() if make_contiguous else shifted
# Reshapes the dims of x from dim i (inclusive)
# to dim j (exclusive) into the desired shape,
# e.g. if x.shape = (b, thw, c) then
# view_range(x, 1, 2, (t, h, w)) returns
# x of shape (b, t, h, w, c)
def view_range(x, i, j, shape):
    """View dims ``[i, j)`` of ``x`` as ``shape``, leaving other dims untouched.

    For example, with ``x.shape == (b, t*h*w, c)``,
    ``view_range(x, 1, 2, (t, h, w))`` has shape ``(b, t, h, w, c)``.

    Args:
        x: Tensor to view.
        i: First dim to replace (inclusive); negative values count from
            the end.
        j: Dim at which to stop (exclusive); negative values count from
            the end and ``None`` means "through the last dim".
        shape: Iterable of sizes that replaces dims ``i`` to ``j``.

    Returns:
        ``x.view(...)`` with the requested dims replaced.

    Raises:
        AssertionError: If the normalised range is empty or out of bounds.
    """
    new_dims = tuple(shape)
    rank = len(x.shape)

    start = i + rank if i < 0 else i
    if j is None:
        stop = rank
    else:
        stop = j + rank if j < 0 else j
    assert 0 <= start < stop <= rank

    current = x.shape
    return x.view(current[:start] + new_dims + current[stop:])
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k.

    Args:
        output: Score tensor; the top-k is taken along dim 1
            (presumably ``(batch, num_classes)`` -- confirm with callers).
        target: Label tensor whose first dimension is the batch size.
        topk: Iterable of ``k`` values to evaluate.

    Returns:
        A list with one single-element tensor per ``k``, each holding the
        percentage of samples whose target is among the top-``k`` predictions.
    """
    with torch.no_grad():
        largest_k = max(topk)
        batch_size = target.size(0)

        # Indices of the best `largest_k` scores per sample, transposed so
        # that row r holds every sample's rank-r prediction.
        top_indices = output.topk(largest_k, dim=1, largest=True, sorted=True)[1]
        ranked = top_indices.t()
        hits = ranked.eq(target.reshape(1, -1).expand_as(ranked))

        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(100.0 / batch_size)
            for k in topk
        ]
def tensor_slice(x, begin, size):
    """Slice ``x`` from per-dimension start indices with per-dimension extents.

    Args:
        x: Tensor (or array) to slice.
        begin: Non-negative start index for each sliced dimension.
        size: Extent for each sliced dimension; ``-1`` means "everything
            from ``begin`` to the end of that dimension".

    Returns:
        ``x[b0:b0 + s0, b1:b1 + s1, ...]`` for the paired entries of
        ``begin`` and the resolved ``size``.

    Raises:
        AssertionError: If any start index is negative or any resolved
            extent is negative.
    """
    assert all(b >= 0 for b in begin)
    size = [dim_len - b if s == -1 else s
            for s, b, dim_len in zip(size, begin, x.shape)]
    assert all(s >= 0 for s in size)
    # Index with a tuple: a *list* of slice objects is not a valid
    # multidimensional index for NumPy arrays and is a legacy form for
    # tensors, whereas a tuple is the documented spelling for both.
    slices = tuple(slice(b, b + s) for b, s in zip(begin, size))
    return x[slices]
def adopt_weight(global_step, threshold=0, value=0.):
    """Return ``value`` while ``global_step`` is below ``threshold``, else ``1``.

    Args:
        global_step: Current step counter.
        threshold: Step at which the weight switches to ``1``.
        value: Weight returned for steps strictly below ``threshold``.
    """
    return value if global_step < threshold else 1
def comp_getattr(args, attr_name, default=None):
    """Return ``args.<attr_name>`` if that attribute exists, else ``default``.

    Args:
        args: Object to read the attribute from.
        attr_name: Name of the attribute.
        default: Value returned when the attribute is missing.

    Returns:
        The attribute value (which may itself be ``None``) or ``default``.
    """
    # Three-argument getattr resolves the attribute once; a hasattr/getattr
    # pair would resolve it twice, running property getters twice.
    return getattr(args, attr_name, default)
def visualize_tensors(t, name=None, nest=0):
    """Print a recursive summary of a nested dict / list / tensor structure.

    Kinds are detected by substring match on ``str(type(obj))`` ('dict',
    'list', 'Tensor'), not by ``isinstance``. For dict-like objects the keys
    are printed, then one line per entry (``None``, tensor shape, or a
    recursive dump for nested dicts and lists; other values print nothing).
    For list-like objects the length is printed and every element is
    visited. Tensor-like objects print their ``.shape``; anything else is
    printed as-is.

    Args:
        t: Object to describe.
        name: Optional label; when given it is printed with the nesting
            level on every call, including recursive ones.
        nest: Current nesting depth (incremented on each recursive call).

    Returns:
        An empty string.
    """
    if name is not None:
        print(name, "current nest: ", nest)
    print("type: ", type(t))

    kind = str(type(t))
    if 'dict' in kind:
        print(t.keys())
        for key in t.keys():
            item = t[key]
            if item is None:
                print(key, "None")
                continue
            item_kind = str(type(item))
            if 'Tensor' in item_kind:
                print(key, item.shape)
            elif 'dict' in item_kind:
                print(key, 'dict')
                visualize_tensors(item, name, nest + 1)
            elif 'list' in item_kind:
                print(key, len(item))
                visualize_tensors(item, name, nest + 1)
    elif 'list' in kind:
        print("list length: ", len(t))
        for element in t:
            visualize_tensors(element, name, nest + 1)
    elif 'Tensor' in kind:
        print(t.shape)
    else:
        print(t)
    return ""
|