""" Copyright (C) 2019 NVIDIA Corporation. All rights reserved. Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). """ import numpy as np from PIL import Image from torchvision import transforms # define constants image_size = 256 # to label values = [12, 2, 6, 8, 1, 10, 3, 14, 11, 4, 5, 13, 9] values = np.array(values) # from color colors = [ (135, 206, 235), (155, 118, 83), (176, 212, 155), (90, 188, 216), (193, 190, 186), (90, 77, 65), (86, 125, 70), (66, 105, 47), (21, 119, 190), (58, 46, 39), (77, 65, 90), (253, 218, 22), (208, 204, 204), ] colors = np.array(colors) def remap_label(arr): # compare only last color channel to speed up arr_b = arr[:, :, 2] # remap color to label for i in range(len(colors)): arr_b[arr_b == colors[i][2]] = values[i] # others to 15 arr_b[arr_b > 15] = 15 return arr_b preprocess = transforms.Compose( [ transforms.Resize([image_size, image_size]), transforms.ToTensor(), ] ) def image_loader(loader, label_inp): image = Image.fromarray(label_inp).convert("RGB") image = image.resize((image_size, image_size)) image = loader(image).float() * 255 image = image.clone().detach().requires_grad_(True) image = image.unsqueeze(0) return image def tensor2im(image_tensor): image_numpy = image_tensor[0].detach().cpu().float().numpy() image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 image_numpy = np.clip(image_numpy, 0, 255) return Image.fromarray(image_numpy.astype(np.uint8)) def get_artwork(model, data, code): label_inp = remap_label(np.array(data)) label_inp = (image_loader(preprocess, label_inp)).detach().half() image_out = model(label_inp, mode="inference", style_codes=code) image_out = tensor2im(image_out) return image_out