File size: 5,108 Bytes
f3daba8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
"""
Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""
import re
import importlib
import torch
from argparse import Namespace
import numpy as np
from PIL import Image
import os
# Converts a Tensor into a Numpy array
# |imtype|: the desired type of the converted numpy array
def tensor2im(image_tensor, imtype=np.uint8, normalize=True, tile=False):
    """Convert a tensor (or a list / batch of tensors) into NumPy image arrays.

    Args:
        image_tensor: a 2-D, 3-D or 4-D tensor, or a list of such tensors.
            A 4-D tensor is treated as a batch along axis 0; a 2-D tensor
            gets a leading axis of size 1 added before conversion.
        imtype: NumPy dtype of the returned array(s).
        normalize: when True, values are mapped as ``(x + 1) / 2 * 255``
            (i.e. -1 -> 0 and 1 -> 255); when False, as ``x * 255``.
        tile: accepted for interface compatibility; this function does not
            use it.

    Returns:
        A list of arrays for list input; an array with a leading batch axis
        for 4-D input; otherwise a single array with axis 0 of the tensor
        moved to the last position (and dropped when its size is 1).
        Values are clipped to [0, 255] before the cast to ``imtype``.
    """
    if isinstance(image_tensor, list):
        return [tensor2im(item, imtype, normalize) for item in image_tensor]

    if image_tensor.dim() == 4:
        # Convert every image of the batch with the SAME imtype/normalize the
        # caller asked for (they used to be dropped here, silently falling
        # back to the defaults), then stack along a new leading axis.
        converted = [
            tensor2im(image_tensor[b], imtype, normalize)[np.newaxis]
            for b in range(image_tensor.size(0))
        ]
        return np.concatenate(converted, axis=0)

    if image_tensor.dim() == 2:
        image_tensor = image_tensor.unsqueeze(0)

    image_numpy = image_tensor.detach().cpu().float().numpy()
    image_numpy = np.transpose(image_numpy, (1, 2, 0))
    if normalize:
        image_numpy = (image_numpy + 1) / 2.0 * 255.0
    else:
        image_numpy = image_numpy * 255.0
    image_numpy = np.clip(image_numpy, 0, 255)

    # A trailing axis of size 1 is squeezed so the result is a plain 2-D array.
    if image_numpy.shape[2] == 1:
        image_numpy = image_numpy[:, :, 0]
    return image_numpy.astype(imtype)
# Converts a one-hot tensor into a colorful label map
def tensor2label(label_tensor, n_label, imtype=np.uint8, tile=False):
    """Turn a label tensor into a color image array.

    Args:
        label_tensor: 1-D, 3-D or 4-D tensor. A 4-D tensor is handled as a
            batch along axis 0.
        n_label: number of labels passed to ``Colorize``; when 0 the tensor
            is converted with ``tensor2im`` instead.
        imtype: NumPy dtype of the returned array.
        tile: only consulted for 4-D input. When True the converted batch is
            passed to ``tile_images``; when False only the first converted
            image of the batch is returned.

    Returns:
        A NumPy array. For 1-D input this is a fixed all-zero ``(64, 64, 3)``
        uint8 placeholder.

    NOTE(review): ``Colorize`` and ``tile_images`` are not defined in the
    visible part of this file -- confirm they are provided elsewhere.
    """
    rank = label_tensor.dim()

    if rank == 4:
        # Convert each sample on its own, then stack along a new batch axis.
        per_sample = [
            tensor2label(label_tensor[idx], n_label, imtype)
            for idx in range(label_tensor.size(0))
        ]
        stacked = np.concatenate(
            [sample[np.newaxis] for sample in per_sample], axis=0
        )
        return tile_images(stacked) if tile else stacked[0]

    if rank == 1:
        return np.zeros((64, 64, 3), dtype=np.uint8)

    if n_label == 0:
        return tensor2im(label_tensor, imtype)

    class_map = label_tensor.cpu().float()
    if class_map.size()[0] > 1:
        # Several planes along axis 0: keep the index of the largest one.
        class_map = class_map.max(0, keepdim=True)[1]
    colored = Colorize(n_label)(class_map)
    return np.transpose(colored.numpy(), (1, 2, 0)).astype(imtype)
def save_image(image_numpy, image_path, create_dir=False):
    """Write an image array to disk, storing ``.jpg`` targets as ``.png``.

    Args:
        image_numpy: 2-D array, or 3-D array whose last axis has size 1 or
            more. A 2-D array or a size-1 last axis is repeated three times
            along the last axis before saving.
        image_path: destination path. A path whose extension is exactly
            ``.jpg`` is written with a ``.png`` extension instead.
        create_dir: when True, the destination's parent directory is created
            first if it does not exist.
    """
    # Swap only the file extension. A blanket str.replace would also rewrite
    # any '.jpg' inside a directory name, making the save path differ from the
    # directory created below.
    stem, extension = os.path.splitext(image_path)
    if extension == '.jpg':
        image_path = stem + '.png'

    if create_dir:
        parent_dir = os.path.dirname(image_path)
        # dirname is '' for a bare file name; os.makedirs('') raises, and the
        # current directory needs no creation anyway.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    if len(image_numpy.shape) == 2:
        image_numpy = np.expand_dims(image_numpy, axis=2)
    if image_numpy.shape[2] == 1:
        image_numpy = np.repeat(image_numpy, 3, 2)

    image_pil = Image.fromarray(image_numpy)
    image_pil.save(image_path)
def mkdirs(paths):
    """Create one directory, or every directory in a list, via ``mkdir``.

    Args:
        paths: either a list of paths (each one is created) or a single
            path, which is handed to ``mkdir`` unchanged.
    """
    if not isinstance(paths, list):
        mkdir(paths)
        return
    for single_path in paths:
        mkdir(single_path)
def mkdir(path):
    """Create directory ``path`` (and missing parents) if it is not there yet.

    Args:
        path: directory to create.

    Raises:
        FileExistsError: if ``path`` exists but is not a directory.
    """
    # exist_ok=True avoids the check-then-create race: another process may
    # create the directory between an os.path.exists() test and makedirs().
    os.makedirs(path, exist_ok=True)
def atoi(text):
    """Return ``int(text)`` when ``text`` is all decimal digits, else ``text``.

    Args:
        text: string to convert.

    Returns:
        An int for a non-empty, all-decimal-digit string; otherwise the
        input string unchanged (including the empty string).
    """
    # isdecimal(), not isdigit(): isdigit() is also True for characters such
    # as superscript digits ('²') that int() refuses to parse.
    return int(text) if text.isdecimal() else text
def natural_keys(text):
    '''
    alist.sort(key=natural_keys) sorts in human order
    http://nedbatchelder.com/blog/200712/human_sorting.html
    (See Toothy's implementation in the comments)

    Splits ``text`` into alternating non-digit and digit runs and returns
    them as a list in which every digit run is converted to an int, e.g.
    'img12.png' -> ['img', 12, '.png'].
    '''
    # Raw string: '\d' in a normal literal is an invalid escape sequence.
    pieces = re.split(r'(\d+)', text)
    # With a capturing group, re.split puts the matched digit runs at the odd
    # positions, so position alone tells which pieces are numbers.
    return [int(piece) if idx % 2 else piece for idx, piece in enumerate(pieces)]
def natural_sort(items):
    """Sort ``items`` in place, using ``natural_keys`` as the sort key.

    Args:
        items: object with a ``sort(key=...)`` method, e.g. a list of
            strings. It is modified in place; nothing is returned.
    """
    items.sort(key=natural_keys)
def str2bool(v):
    """Parse a yes/no style string into a bool.

    Args:
        v: a bool (returned unchanged) or a string; matching is
            case-insensitive.

    Returns:
        True for 'yes', 'true', 't', 'y', '1'; False for 'no', 'false',
        'f', 'n', '0'.

    Raises:
        argparse.ArgumentTypeError: for any other string.
    """
    # Imported here because the module only does
    # `from argparse import Namespace`, which leaves the name `argparse`
    # itself undefined at module level.
    import argparse

    if isinstance(v, bool):
        return v

    lowered = v.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
def find_class_in_module(target_cls_name, module):
    """Look up a module attribute by case-insensitive, underscore-free name.

    Args:
        target_cls_name: name to search for; underscores are removed and the
            result is lowercased before comparison.
        module: dotted module name, imported with ``importlib.import_module``.

    Returns:
        The attribute of the module whose lowercased name equals the
        normalized target. If several names match, the last one in the
        module's ``__dict__`` wins.

    Raises:
        SystemExit: with exit status 1, after printing a message, when no
            matching (non-None) attribute exists.
    """
    target_cls_name = target_cls_name.replace('_', '').lower()
    clslib = importlib.import_module(module)

    cls = None
    for name, clsobj in clslib.__dict__.items():
        if name.lower() == target_cls_name:
            cls = clsobj

    if cls is None:
        print("In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, target_cls_name))
        # A failed lookup is an error: exit with a non-zero status, and do not
        # rely on the `exit` helper that only the `site` module injects.
        raise SystemExit(1)

    return cls
def save_network(net, label, epoch, opt):
    """Save ``net``'s state dict to ``<checkpoints_dir>/<name>/<epoch>_net_<label>.pth``.

    Args:
        net: module exposing ``cpu()``, ``cuda()`` and ``state_dict()``.
        label: network label used in the file name.
        epoch: epoch identifier used in the file name.
        opt: options object; ``checkpoints_dir``, ``name`` and ``gpu_ids``
            are read from it.

    The network is moved to the CPU before its state dict is saved. It is
    moved back with ``net.cuda()`` afterwards when ``opt.gpu_ids`` is
    non-empty and CUDA is available.
    """
    checkpoint_file = '%s_net_%s.pth' % (epoch, label)
    checkpoint_path = os.path.join(opt.checkpoints_dir, opt.name, checkpoint_file)

    cpu_state = net.cpu().state_dict()
    torch.save(cpu_state, checkpoint_path)

    restore_to_gpu = len(opt.gpu_ids) and torch.cuda.is_available()
    if restore_to_gpu:
        net.cuda()
def load_network(net, label, epoch, opt):
    """Load ``<checkpoints_dir>/<name>/<epoch>_net_<label>.pth`` into ``net``.

    Args:
        net: module exposing ``load_state_dict``.
        label: network label used in the file name.
        epoch: epoch identifier used in the file name.
        opt: options object; ``checkpoints_dir`` and ``name`` are read.

    Returns:
        The same ``net`` object, after loading.

    Loading uses ``strict=False``, so keys that are missing from or
    unexpected in the checkpoint are ignored rather than raising.
    """
    save_filename = '%s_net_%s.pth' % (epoch, label)
    save_dir = os.path.join(opt.checkpoints_dir, opt.name)
    save_path = os.path.join(save_dir, save_filename)

    # Deserialize onto the CPU so a checkpoint that holds CUDA tensors can
    # still be read on a host without a GPU; load_state_dict then copies the
    # values into net's own parameters wherever they live.
    weights = torch.load(save_path, map_location='cpu')
    net.load_state_dict(weights, strict=False)
    return net
|