import os
import importlib
from inspect import isfunction

import numpy as np
import torch


def shape_to_str(x):
    """Format a tensor/array shape as a compact string, e.g. (2, 3, 4) -> "2x3x4"."""
    return "x".join(str(d) for d in x.shape)


def str2bool(v):
    """Parse a command-line style string (e.g. 'yes', 'false', '1') into a bool."""
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise ValueError('Boolean value expected.')


def get_obj_from_str(string, reload=False):
    """Resolve a dotted path such as 'package.module.ClassName' to the object it names."""
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


def instantiate_from_config(config):
    """Instantiate config['target'] with config['params'] as keyword arguments."""
    if "target" not in config:
        if config == '__is_first_stage__':
            return None
        elif config == "__is_unconditional__":
            return None
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))


def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True):
    """
    Shift src_dim to dest_dim, e.g. shift_dim(x, 1, -1) maps
    (b, c, t, h, w) -> (b, t, h, w, c).
    """
    n_dims = len(x.shape)
    if src_dim < 0:
        src_dim = n_dims + src_dim
    if dest_dim < 0:
        dest_dim = n_dims + dest_dim
    assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims

    dims = list(range(n_dims))
    del dims[src_dim]

    permutation = []
    ctr = 0
    for i in range(n_dims):
        if i == dest_dim:
            permutation.append(src_dim)
        else:
            permutation.append(dims[ctr])
            ctr += 1
    x = x.permute(permutation)
    if make_contiguous:
        x = x.contiguous()
    return x


def torch_to_np(x):
    """Convert a [-1, 1] float tensor, (b, c, h, w) or (b, c, t, h, w), to a uint8
    channel-last numpy array."""
    sample = x.detach().cpu()
    sample = ((sample + 1) * 127.5).clamp(0, 255).to(torch.uint8)
    if sample.dim() == 5:
        sample = sample.permute(0, 2, 3, 4, 1)
    else:
        sample = sample.permute(0, 2, 3, 1)
    sample = sample.contiguous().numpy()
    return sample


def np_to_torch_video(x):
    """Convert a uint8 video array [t, h, w, c] to a float tensor [c, t, h, w] in [-1, 1]."""
    x = torch.tensor(x).permute(3, 0, 1, 2).float()  # [t,h,w,c] -> [c,t,h,w]
    x = (x / 255 - 0.5) * 2
    return x


def load_npz_from_dir(data_dir):
    """Load and concatenate the 'arr_0' array of every .npz file in a directory."""
    data = [np.load(os.path.join(data_dir, data_name))['arr_0']
            for data_name in os.listdir(data_dir)]
    data = np.concatenate(data, axis=0)
    return data


def load_npz_from_paths(data_paths):
    """Load and concatenate the 'arr_0' array of each .npz file in the given paths."""
    data = [np.load(data_path)['arr_0'] for data_path in data_paths]
    data = np.concatenate(data, axis=0)
    return data


def ismap(x):
    """True if x is a 4-D tensor with more than 3 channels (e.g. a feature/segmentation map)."""
    if not isinstance(x, torch.Tensor):
        return False
    return (len(x.shape) == 4) and (x.shape[1] > 3)


def isimage(x):
    """True if x is a 4-D tensor with 1 or 3 channels."""
    if not isinstance(x, torch.Tensor):
        return False
    return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)


def exists(x):
    return x is not None


def default(val, d):
    """Return val if it is not None, otherwise the default d (called if d is a function)."""
    if exists(val):
        return val
    return d() if isfunction(d) else d


def mean_flat(tensor):
    """
    https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
    Take the mean over all non-batch dimensions.
    """
    return tensor.mean(dim=list(range(1, len(tensor.shape))))


def count_params(model, verbose=False):
    """Count the total number of parameters in a model."""
    total_params = sum(p.numel() for p in model.parameters())
    if verbose:
        print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
    return total_params


def check_istarget(name, para_list):
    """
    Return True if any partial target-parameter name in para_list occurs in the
    full source-parameter name.
    """
    return any(para in name for para in para_list)
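

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the module's API): exercises shift_dim,
# np_to_torch_video and torch_to_np on random data to show the expected layouts
# and value ranges. The shapes below are assumptions chosen for demonstration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # (b, c, t, h, w) -> (b, t, h, w, c): move the channel dim to the end.
    video = torch.randn(2, 3, 8, 16, 16)
    print(shape_to_str(shift_dim(video, 1, -1)))  # 2x8x16x16x3

    # Round-trip a uint8 clip through the normalization helpers.
    clip = np.random.randint(0, 256, size=(8, 16, 16, 3), dtype=np.uint8)  # [t,h,w,c]
    clip_pt = np_to_torch_video(clip)         # [c,t,h,w], float in [-1, 1]
    back = torch_to_np(clip_pt.unsqueeze(0))  # [b,t,h,w,c], uint8 again
    print(clip_pt.shape, back.shape, back.dtype)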