""" Copyright (C) 2019 NVIDIA Corporation. All rights reserved. Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). """ import re import importlib import torch from argparse import Namespace import numpy as np from PIL import Image import os # Converts a Tensor into a Numpy array # |imtype|: the desired type of the converted numpy array def tensor2im(image_tensor, imtype=np.uint8, normalize=True, tile=False): if isinstance(image_tensor, list): image_numpy = [] for i in range(len(image_tensor)): image_numpy.append(tensor2im(image_tensor[i], imtype, normalize)) return image_numpy if image_tensor.dim() == 4: # transform each image in the batch images_np = [] for b in range(image_tensor.size(0)): one_image = image_tensor[b] one_image_np = tensor2im(one_image) images_np.append(one_image_np.reshape(1, *one_image_np.shape)) images_np = np.concatenate(images_np, axis=0) return images_np if image_tensor.dim() == 2: image_tensor = image_tensor.unsqueeze(0) image_numpy = image_tensor.detach().cpu().float().numpy() if normalize: image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 else: image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0 image_numpy = np.clip(image_numpy, 0, 255) if image_numpy.shape[2] == 1: image_numpy = image_numpy[:, :, 0] return image_numpy.astype(imtype) # Converts a one-hot tensor into a colorful label map def tensor2label(label_tensor, n_label, imtype=np.uint8, tile=False): if label_tensor.dim() == 4: # transform each image in the batch images_np = [] for b in range(label_tensor.size(0)): one_image = label_tensor[b] one_image_np = tensor2label(one_image, n_label, imtype) images_np.append(one_image_np.reshape(1, *one_image_np.shape)) images_np = np.concatenate(images_np, axis=0) if tile: images_tiled = tile_images(images_np) return images_tiled else: images_np = images_np[0] return images_np if label_tensor.dim() == 1: return np.zeros((64, 64, 3), dtype=np.uint8) if n_label == 0: return tensor2im(label_tensor, imtype) label_tensor = label_tensor.cpu().float() if label_tensor.size()[0] > 1: label_tensor = label_tensor.max(0, keepdim=True)[1] label_tensor = Colorize(n_label)(label_tensor) label_numpy = np.transpose(label_tensor.numpy(), (1, 2, 0)) result = label_numpy.astype(imtype) return result def save_image(image_numpy, image_path, create_dir=False): if create_dir: os.makedirs(os.path.dirname(image_path), exist_ok=True) if len(image_numpy.shape) == 2: image_numpy = np.expand_dims(image_numpy, axis=2) if image_numpy.shape[2] == 1: image_numpy = np.repeat(image_numpy, 3, 2) image_pil = Image.fromarray(image_numpy) # save to png image_pil.save(image_path.replace('.jpg', '.png')) def mkdirs(paths): if isinstance(paths, list) and not isinstance(paths, str): for path in paths: mkdir(path) else: mkdir(paths) def mkdir(path): if not os.path.exists(path): os.makedirs(path) def atoi(text): return int(text) if text.isdigit() else text def natural_keys(text): ''' alist.sort(key=natural_keys) sorts in human order http://nedbatchelder.com/blog/200712/human_sorting.html (See Toothy's implementation in the comments) ''' return [atoi(c) for c in re.split('(\d+)', text)] def natural_sort(items): items.sort(key=natural_keys) def str2bool(v): if v.lower() in ('yes', 'true', 't', 'y', '1'): return True elif v.lower() in ('no', 'false', 'f', 'n', '0'): return False else: raise argparse.ArgumentTypeError('Boolean value expected.') def find_class_in_module(target_cls_name, module): target_cls_name = target_cls_name.replace('_', '').lower() clslib = importlib.import_module(module) cls = None for name, clsobj in clslib.__dict__.items(): if name.lower() == target_cls_name: cls = clsobj if cls is None: print("In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, target_cls_name)) exit(0) return cls def save_network(net, label, epoch, opt): save_filename = '%s_net_%s.pth' % (epoch, label) save_path = os.path.join(opt.checkpoints_dir, opt.name, save_filename) torch.save(net.cpu().state_dict(), save_path) if len(opt.gpu_ids) and torch.cuda.is_available(): net.cuda() def load_network(net, label, epoch, opt): save_filename = '%s_net_%s.pth' % (epoch, label) save_dir = os.path.join(opt.checkpoints_dir, opt.name) save_path = os.path.join(save_dir, save_filename) weights = torch.load(save_path) net.load_state_dict(weights, strict=False)# return net