|
""" Adapted from https://github.com/SongweiGe/TATS""" |
|
|
|
|
|
import warnings |
|
import torch |
|
import imageio |
|
|
|
import math |
|
import numpy as np |
|
|
|
import sys |
|
import pdb as pdb_original |
|
|
|
import logging |
|
|
|
import imageio.core.util |
|
# Raise the "imageio_ffmpeg" logger's threshold so only ERROR-level (and
# more severe) records from it are emitted.
logging.getLogger(name="imageio_ffmpeg").setLevel(level=logging.ERROR)
|
|
|
|
|
class ForkedPdb(pdb_original.Pdb):
    """A Pdb subclass that may be used from a forked multiprocessing child.

    ``multiprocessing`` gives child processes a ``sys.stdin`` that is not the
    controlling terminal, so a plain ``pdb.set_trace()`` cannot read commands.
    This subclass temporarily re-points ``sys.stdin`` at ``/dev/stdin`` for the
    duration of each interactive prompt. POSIX-only: it relies on
    ``/dev/stdin`` existing.
    """

    def interaction(self, *args, **kwargs):
        """Run the standard Pdb interaction with ``sys.stdin`` on ``/dev/stdin``.

        The previous ``sys.stdin`` is always restored, and the temporary
        ``/dev/stdin`` handle is always closed, even if the interaction raises.
        """
        saved_stdin = sys.stdin
        # The original opened /dev/stdin without ever closing it, leaking one
        # file handle per prompt; the context manager closes it.
        with open('/dev/stdin') as tty_stdin:
            sys.stdin = tty_stdin
            try:
                super().interaction(*args, **kwargs)
            finally:
                # Restore before the handle is closed so sys.stdin never
                # refers to a closed file.
                sys.stdin = saved_stdin
|
|
|
|
|
|
|
|
|
def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True):
    """Move dimension ``src_dim`` of ``x`` so it ends up at index ``dest_dim``.

    All other dimensions keep their relative order. Negative indices count
    from the end, as in Python indexing.

    Args:
        x: Tensor exposing ``.shape`` and ``.permute`` (and ``.contiguous``
            when ``make_contiguous`` is true).
        src_dim: Index of the dimension to move.
        dest_dim: Index the moved dimension occupies in the result.
        make_contiguous: If true, call ``.contiguous()`` on the permuted result.

    Returns:
        The permuted tensor.

    Raises:
        AssertionError: If either index is out of range after normalization.
    """
    rank = len(x.shape)
    src = src_dim + rank if src_dim < 0 else src_dim
    dest = dest_dim + rank if dest_dim < 0 else dest_dim
    assert 0 <= src < rank and 0 <= dest < rank

    # Take every axis except the moving one, then re-insert it at its target.
    order = [axis for axis in range(rank) if axis != src]
    order.insert(dest, src)

    moved = x.permute(order)
    return moved.contiguous() if make_contiguous else moved
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def view_range(x, i, j, shape):
    """Reshape the run of dimensions ``[i, j)`` of ``x`` into ``shape``.

    Dimensions before ``i`` and from ``j`` onward are left unchanged. Negative
    ``i``/``j`` count from the end; ``j=None`` means "through the last dim".

    Args:
        x: Tensor to reshape (via ``x.view``, so its memory layout must allow it).
        i: First dimension of the range (inclusive).
        j: End of the range (exclusive), or ``None``.
        shape: Iterable of sizes that replaces dims ``i..j-1``.

    Returns:
        ``x.view(...)`` with the new shape.

    Raises:
        AssertionError: If the normalized range is empty or out of bounds.
    """
    replacement = tuple(shape)
    rank = len(x.shape)

    start = i + rank if i < 0 else i
    if j is None:
        stop = rank
    else:
        stop = j + rank if j < 0 else j
    assert 0 <= start < stop <= rank

    dims = x.shape
    return x.view(dims[:start] + replacement + dims[stop:])
|
|
|
|
|
def accuracy(output, target, topk=(1,)):
    """Compute top-k classification accuracy, in percent, for each k in ``topk``.

    Args:
        output: Score tensor of shape ``(batch, num_classes)``; the top-k
            classes are taken along dim 1.
        target: Class-index tensor with ``batch`` elements.
        topk: Iterable of ``k`` values to evaluate.

    Returns:
        A list, parallel to ``topk``, of 1-element float tensors holding the
        fraction of samples whose target is among their top ``k`` predictions,
        scaled to ``[0, 100]`` by ``target.size(0)``.
    """
    with torch.no_grad():
        largest_k = max(topk)
        n_samples = target.size(0)

        # topk indices are (batch, largest_k); transpose so row r holds every
        # sample's r-th best guess, then compare each row against the targets.
        top_idx = output.topk(largest_k, dim=1, largest=True, sorted=True)[1].t()
        hits = top_idx.eq(target.reshape(1, -1).expand_as(top_idx))

        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(100.0 / n_samples)
            for k in topk
        ]
|
|
|
|
|
def tensor_slice(x, begin, size):
    """Return the sub-block of ``x`` that starts at ``begin`` and spans ``size``.

    Args:
        x: Tensor or array supporting ``.shape`` and basic slicing.
        begin: Non-negative start offset for each sliced dimension.
        size: Length for each sliced dimension; ``-1`` means "from ``begin``
            through the end of that dimension".

    Only the first ``min(len(begin), len(size), x.ndim)`` dimensions are
    sliced (``zip`` truncation); any later dimensions are kept whole.

    Returns:
        ``x`` indexed with one ``slice(b, b + s)`` per sliced dimension.

    Raises:
        AssertionError: If any ``begin`` entry or resolved size is negative.
    """
    assert all(b >= 0 for b in begin)
    size = [dim - b if s == -1 else s
            for s, b, dim in zip(size, begin, x.shape)]
    assert all(s >= 0 for s in size)

    # A multi-dimensional index must be a tuple: NumPy >= 1.23 no longer
    # interprets a list of slices as one (it raises IndexError), and a tuple
    # is the documented form for torch tensors as well.
    slices = tuple(slice(b, b + s) for b, s in zip(begin, size))
    return x[slices]
|
|
|
|
|
def adopt_weight(global_step, threshold=0, value=0.):
    """Return ``value`` while ``global_step < threshold``, otherwise ``1``.

    Args:
        global_step: Current step counter.
        threshold: First step at which the weight switches to ``1``.
        value: Weight returned before ``threshold`` is reached.
    """
    return value if global_step < threshold else 1
|
|
|
def comp_getattr(args, attr_name, default=None):
    """Return ``args.<attr_name>`` if it exists, otherwise ``default``.

    Uses the three-argument ``getattr``, which (like ``hasattr``) treats only
    ``AttributeError`` as "missing". Unlike the previous ``hasattr`` followed
    by ``getattr``, the attribute is looked up once, so a property or
    ``__getattr__`` on ``args`` is evaluated a single time.

    Args:
        args: Any object (e.g. a parsed-arguments namespace).
        attr_name: Name of the attribute to read.
        default: Value returned when the attribute is absent.
    """
    return getattr(args, attr_name, default)
|
|
|
|
|
def visualize_tensors(t, name=None, nest=0):
    """Print a recursive, shape-oriented summary of ``t`` to stdout.

    Type detection is by substring match on ``str(type(...))``: "dict",
    "list" and "Tensor". Dict-like values print their keys and then one line
    per key. Tensor values print their shape, ``None`` values print "None",
    and nested dict/list values are recursed into with ``nest + 1``. Other
    dict values are not printed. List-like values print their length and
    recurse into each item. Tensors print their shape, and anything else is
    printed as-is.

    Args:
        t: Object to summarize.
        name: Optional label printed (with the nesting depth) on every call.
        nest: Current recursion depth, used only in that label.

    Returns:
        An empty string.
    """
    if name is not None:
        print(name, "current nest: ", nest)
    print("type: ", type(t))
    kind = str(type(t))

    if 'dict' in kind:
        print(t.keys())
        for key in t.keys():
            value = t[key]
            if value is None:
                print(key, "None")
                continue
            value_kind = str(type(value))
            if 'Tensor' in value_kind:
                print(key, value.shape)
            elif 'dict' in value_kind:
                print(key, 'dict')
                visualize_tensors(value, name, nest + 1)
            elif 'list' in value_kind:
                print(key, len(value))
                visualize_tensors(value, name, nest + 1)
    elif 'list' in kind:
        print("list length: ", len(t))
        for item in t:
            visualize_tensors(item, name, nest + 1)
    elif 'Tensor' in kind:
        print(t.shape)
    else:
        print(t)
    return ""
|
|