# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import re
import importlib
import torch
from argparse import Namespace
import numpy as np
from PIL import Image
import os
import argparse
import dill as pickle


def save_obj(obj, name):
    with open(name, "wb") as f:
        pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)


def load_obj(name):
    with open(name, "rb") as f:
        return pickle.load(f)


def copyconf(default_opt, **kwargs):
    conf = argparse.Namespace(**vars(default_opt))
    for key in kwargs:
        print(key, kwargs[key])
        setattr(conf, key, kwargs[key])
    return conf


# Converts a Tensor into a Numpy array
# |imtype|: the desired type of the converted numpy array
def tensor2im(image_tensor, imtype=np.uint8, normalize=True, tile=False):
    if isinstance(image_tensor, list):
        image_numpy = []
        for i in range(len(image_tensor)):
            image_numpy.append(tensor2im(image_tensor[i], imtype, normalize))
        return image_numpy

    if image_tensor.dim() == 4:
        # transform each image in the batch
        images_np = []
        for b in range(image_tensor.size(0)):
            one_image = image_tensor[b]
            one_image_np = tensor2im(one_image)
            images_np.append(one_image_np.reshape(1, *one_image_np.shape))
        images_np = np.concatenate(images_np, axis=0)
        return images_np

    if image_tensor.dim() == 2:
        image_tensor = image_tensor.unsqueeze(0)
    image_numpy = image_tensor.detach().cpu().float().numpy()
    if normalize:
        # map from [-1, 1] to [0, 255]
        image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
    else:
        image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0
    image_numpy = np.clip(image_numpy, 0, 255)
    if image_numpy.shape[2] == 1:
        image_numpy = image_numpy[:, :, 0]
    return image_numpy.astype(imtype)


# Converts a one-hot tensor into a colorful label map
def tensor2label(label_tensor, n_label, imtype=np.uint8, tile=False):
    if label_tensor.dim() == 4:
        # transform each image in the batch
        images_np = []
        for b in range(label_tensor.size(0)):
            one_image = label_tensor[b]
            one_image_np = tensor2label(one_image, n_label, imtype)
            images_np.append(one_image_np.reshape(1, *one_image_np.shape))
        images_np = np.concatenate(images_np, axis=0)
        # if tile:
        #     images_tiled = tile_images(images_np)
        #     return images_tiled
        # else:
        #     images_np = images_np[0]
        #     return images_np
        return images_np

    if label_tensor.dim() == 1:
        return np.zeros((64, 64, 3), dtype=np.uint8)
    if n_label == 0:
        return tensor2im(label_tensor, imtype)
    label_tensor = label_tensor.cpu().float()
    if label_tensor.size()[0] > 1:
        # collapse the one-hot channels to a single index map via argmax
        label_tensor = label_tensor.max(0, keepdim=True)[1]
    label_tensor = Colorize(n_label)(label_tensor)
    label_numpy = np.transpose(label_tensor.numpy(), (1, 2, 0))
    result = label_numpy.astype(imtype)
    return result


def save_image(image_numpy, image_path, create_dir=False):
    if create_dir:
        os.makedirs(os.path.dirname(image_path), exist_ok=True)

    if len(image_numpy.shape) == 2:
        image_numpy = np.expand_dims(image_numpy, axis=2)
    if image_numpy.shape[2] == 1:
        image_numpy = np.repeat(image_numpy, 3, 2)
    image_pil = Image.fromarray(image_numpy)

    # always save as PNG, replacing a .jpg extension if present
    image_pil.save(image_path.replace(".jpg", ".png"))


def mkdirs(paths):
    if isinstance(paths, list) and not isinstance(paths, str):
        for path in paths:
            mkdir(path)
    else:
        mkdir(paths)


def mkdir(path):
    if not os.path.exists(path):
        os.makedirs(path)


def atoi(text):
    return int(text) if text.isdigit() else text


def natural_keys(text):
    """
    alist.sort(key=natural_keys) sorts in human order
    http://nedbatchelder.com/blog/200712/human_sorting.html
    (See Toothy's implementation in the comments)
    """
    return [atoi(c) for c in re.split(r"(\d+)", text)]


def natural_sort(items):
    items.sort(key=natural_keys)
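
# Example (illustrative, not from the original module):
#   files = ["frame10.png", "frame2.png", "frame1.png"]
#   natural_sort(files)  # files -> ["frame1.png", "frame2.png", "frame10.png"]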


def str2bool(v):
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    elif v.lower() in ("no", "false", "f", "n", "0"):
        return False
    else:
        raise argparse.ArgumentTypeError("Boolean value expected.")


def find_class_in_module(target_cls_name, module):
    # match class names case-insensitively, ignoring underscores
    target_cls_name = target_cls_name.replace("_", "").lower()
    clslib = importlib.import_module(module)
    cls = None
    for name, clsobj in clslib.__dict__.items():
        if name.lower() == target_cls_name:
            cls = clsobj

    if cls is None:
        print(
            "In %s, there should be a class whose name matches %s in lowercase without underscore(_)"
            % (module, target_cls_name)
        )
        exit(0)

    return cls


def save_network(net, label, epoch, opt):
    save_filename = "%s_net_%s.pth" % (epoch, label)
    save_path = os.path.join(opt.checkpoints_dir, opt.name, save_filename)
    torch.save(net.cpu().state_dict(), save_path)
    if len(opt.gpu_ids) and torch.cuda.is_available():
        net.cuda()


def load_network(net, label, epoch, opt):
    save_filename = "%s_net_%s.pth" % (epoch, label)
    save_dir = os.path.join(opt.checkpoints_dir, opt.name)
    save_path = os.path.join(save_dir, save_filename)
    if os.path.exists(save_path):
        weights = torch.load(save_path)
        net.load_state_dict(weights)
    # if no checkpoint exists at save_path, the network is returned unchanged
    return net


###############################################################################
# Code from
# https://github.com/ycszen/pytorch-seg/blob/master/transform.py
# Modified so it complies with the Cityscapes label map colors
###############################################################################
def uint82bin(n, count=8):
    """Returns the binary string of integer n; count is the number of bits."""
    return "".join([str((n >> y) & 1) for y in range(count - 1, -1, -1)])


class Colorize(object):
    def __init__(self, n=35):
        # labelcolormap(n) is assumed to be provided elsewhere in this module
        # and to return an (n, 3) uint8 color palette.
        self.cmap = labelcolormap(n)
        self.cmap = torch.from_numpy(self.cmap[:n])

    def __call__(self, gray_image):
        size = gray_image.size()
        color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0)

        for label in range(0, len(self.cmap)):
            mask = (label == gray_image[0]).cpu()
            color_image[0][mask] = self.cmap[label][0]
            color_image[1][mask] = self.cmap[label][1]
            color_image[2][mask] = self.cmap[label][2]

        return color_image
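

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). It exercises
# only the self-contained helpers above; the "./demo_results" directory and
# the tensor shape are assumptions made up for this example.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Fake generator output in [-1, 1], the range tensor2im(normalize=True) expects.
    fake_image = torch.rand(1, 3, 64, 64) * 2 - 1

    # Convert the batch to a (1, 64, 64, 3) uint8 array and write the first
    # image; save_image always writes PNG, so the .jpg suffix becomes .png.
    image_np = tensor2im(fake_image)
    save_image(image_np[0], "./demo_results/fake.jpg", create_dir=True)

    # Command-line style boolean parsing.
    assert str2bool("yes") and not str2bool("0")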