# vshirasuna's picture
# Move code to 3dgrid_vqgan folder
# a4c759f
""" Adapted from https://github.com/SongweiGe/TATS"""
# Copyright (c) Meta Platforms, Inc. All Rights Reserved
import warnings
import torch
import imageio
import math
import numpy as np
import sys
import pdb as pdb_original
# import SimpleITK as sitk
import logging
import imageio.core.util
# Raise the "imageio_ffmpeg" logger's threshold to ERROR so that its
# lower-severity (warning/info/debug) records are not emitted.
logging.getLogger("imageio_ffmpeg").setLevel(logging.ERROR)
class ForkedPdb(pdb_original.Pdb):
    """A Pdb subclass that may be used from a forked multiprocessing child.

    A child process may not have a usable ``sys.stdin``. To cope with that,
    each interactive session temporarily rebinds ``sys.stdin`` to a freshly
    opened ``/dev/stdin`` (so this is POSIX-only). It restores the previous
    ``sys.stdin`` afterwards. Typical use, inside the child process::

        ForkedPdb().set_trace()
    """

    def interaction(self, *args, **kwargs):
        """Run a regular Pdb interaction with stdin reattached to /dev/stdin.

        All positional and keyword arguments are forwarded unchanged to
        ``pdb.Pdb.interaction``.
        """
        saved_stdin = sys.stdin
        try:
            # The `with` closes the handle when the session ends, instead of
            # leaving one open file descriptor behind per breakpoint hit.
            with open('/dev/stdin') as child_stdin:
                sys.stdin = child_stdin
                pdb_original.Pdb.interaction(self, *args, **kwargs)
        finally:
            # Always restore the caller's stdin, even if the debugger raised.
            sys.stdin = saved_stdin
def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True):
    """Move dimension ``src_dim`` of ``x`` to position ``dest_dim``.

    All other dimensions keep their relative order. For example,
    ``shift_dim(x, 1, -1)`` turns a ``(b, c, t, h, w)`` tensor into a
    ``(b, t, h, w, c)`` tensor.

    Args:
        x: tensor to permute.
        src_dim: index of the dimension to move (negative indices allowed).
        dest_dim: index the moved dimension ends up at (negative allowed).
        make_contiguous: if True, return ``.contiguous()`` of the permuted
            tensor; otherwise return the permuted result as-is.

    Returns:
        The permuted tensor.

    Raises:
        AssertionError: if either index is out of range for ``x``.
    """
    rank = len(x.shape)
    src = src_dim + rank if src_dim < 0 else src_dim
    dest = dest_dim + rank if dest_dim < 0 else dest_dim
    assert 0 <= src < rank and 0 <= dest < rank

    # Remove the source axis from the identity order, then re-insert it at
    # the destination slot; every other axis keeps its relative position.
    order = [axis for axis in range(rank) if axis != src]
    order.insert(dest, src)

    moved = x.permute(order)
    return moved.contiguous() if make_contiguous else moved
def view_range(x, i, j, shape):
    """Reshape the dimensions ``i`` (inclusive) to ``j`` (exclusive) of ``x``.

    Dimensions outside ``[i, j)`` are left untouched, and the span in between
    is replaced by ``shape``. For example, if ``x.shape == (b, thw, c)``,
    then ``view_range(x, 1, 2, (t, h, w))`` has shape ``(b, t, h, w, c)``.

    Args:
        x: tensor to reshape; ``x.view`` must accept the resulting shape.
        i: first dimension of the span (negative indices allowed).
        j: end of the span, exclusive. ``None`` means "through the last
            dimension"; negative indices are allowed.
        shape: iterable of sizes that replaces the span.

    Returns:
        ``x.view(...)`` with the span replaced by ``shape``.

    Raises:
        AssertionError: if the normalized span is empty or out of range.
    """
    middle = tuple(shape)
    rank = len(x.shape)

    start = i + rank if i < 0 else i
    if j is None:
        stop = rank
    elif j < 0:
        stop = j + rank
    else:
        stop = j
    assert 0 <= start < stop <= rank

    dims = x.shape
    return x.view(dims[:start] + middle + dims[stop:])
def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy, in percent, for each ``k`` in ``topk``.

    Args:
        output: 2-D score tensor; the ``max(topk)`` highest entries along
            dim 1 of each row are taken as that row's predictions.
        target: tensor of true class indices, one per row of ``output``;
            ``target.size(0)`` is used as the batch size.
        topk: the values of ``k`` to evaluate.

    Returns:
        A list parallel to ``topk``. Each entry is a 1-element float tensor
        holding ``100 * (#rows whose target is in the top k) / batch_size``.
    """
    with torch.no_grad():
        widest_k = max(topk)
        n_samples = target.size(0)
        # After the transpose, row r lists every sample's (r+1)-th best class.
        ranked = output.topk(widest_k, dim=1, largest=True, sorted=True).indices.t()
        hits = ranked.eq(target.reshape(1, -1).expand_as(ranked))
        # Each column has at most one hit, so summing the first k rows
        # counts the samples whose target lies within their top k.
        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(100.0 / n_samples)
            for k in topk
        ]
def tensor_slice(x, begin, size):
    """Return the sub-block of ``x`` that starts at ``begin`` with extent ``size``.

    For each sliced dimension ``d``, the result covers
    ``begin[d] : begin[d] + size[d]``. A size of ``-1`` means "everything
    from ``begin[d]`` to the end of that dimension". Only the leading
    dimensions covered by ``begin``/``size`` are sliced (``zip`` truncates);
    any remaining dimensions are kept whole.

    Works for anything that supports tuple-of-slices indexing (torch tensors,
    NumPy arrays).

    Args:
        x: tensor/array to slice.
        begin: non-negative start index per sliced dimension.
        size: extent per sliced dimension, or ``-1`` for "to the end".

    Returns:
        ``x`` indexed by the corresponding slices.

    Raises:
        AssertionError: if any ``begin`` entry is negative or any resolved
            size is negative.
    """
    assert all(b >= 0 for b in begin)
    size = [dim - b if s == -1 else s for s, b, dim in zip(size, begin, x.shape)]
    assert all(s >= 0 for s in size)
    # Index with a tuple. A *list* of slices is a legacy NumPy convention
    # (an IndexError in NumPy >= 1.23) that torch only accepts by
    # special-casing.
    slices = tuple(slice(b, b + s) for b, s in zip(begin, size))
    return x[slices]
def adopt_weight(global_step, threshold=0, value=0.):
    """Return a step-gated weight.

    Args:
        global_step: current step counter.
        threshold: first step at which the weight switches to ``1``.
        value: weight returned while ``global_step < threshold``.

    Returns:
        ``value`` while ``global_step < threshold``, otherwise the int ``1``.
    """
    return value if global_step < threshold else 1
def comp_getattr(args, attr_name, default=None):
    """Return ``args.<attr_name>`` if it is present, otherwise ``default``.

    Uses the three-argument ``getattr``, which performs a single lookup.
    ``hasattr`` followed by ``getattr`` would evaluate a property twice.
    As with ``hasattr``, an ``AttributeError`` raised during the lookup
    yields ``default``.

    Args:
        args: object to read from (e.g. a parsed-arguments namespace).
        attr_name: name of the attribute to fetch.
        default: value returned when the attribute is missing.
    """
    return getattr(args, attr_name, default)
def visualize_tensors(t, name=None, nest=0):
    """Print a recursive summary of a (possibly nested) structure of tensors.

    Each call prints ``type(t)``. What follows depends on the kind of ``t``:

    * mappings (any ``collections.abc.Mapping``): the keys, then one line per
      key. Tensors are shown by shape, ``None`` as "None", nested mappings
      and lists are recursed into, and other values are not printed.
    * lists: the length, then a recursive summary of each element.
    * tensors (including subclasses such as ``torch.nn.Parameter``): the
      shape.
    * anything else: the object itself.

    Args:
        t: object to describe.
        name: optional label, printed with the nesting level on every call.
        nest: current recursion depth; callers normally leave it at 0.

    Returns:
        An empty string.
    """
    # isinstance checks replace substring matching on str(type(t)). The
    # substring version matched dict_keys and list_iterator (then crashed on
    # .keys()/len()) and missed Tensor subclasses such as Parameter.
    from collections.abc import Mapping

    if name is not None:
        print(name, "current nest: ", nest)
    print("type: ", type(t))
    if isinstance(t, Mapping):
        print(t.keys())
        for k in t.keys():
            value = t[k]
            if value is None:
                print(k, "None")
            elif isinstance(value, torch.Tensor):
                print(k, value.shape)
            elif isinstance(value, Mapping):
                print(k, 'dict')
                visualize_tensors(value, name, nest + 1)
            elif isinstance(value, list):
                print(k, len(value))
                visualize_tensors(value, name, nest + 1)
    elif isinstance(t, list):
        print("list length: ", len(t))
        for item in t:
            visualize_tensors(item, name, nest + 1)
    elif isinstance(t, torch.Tensor):
        print(t.shape)
    else:
        print(t)
    return ""