|
""" |
|
Copyright (C) 2019 NVIDIA Corporation. All rights reserved. |
|
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). |
|
""" |
|
import numpy as np |
|
from PIL import Image |
|
from torchvision import transforms |
|
|
|
|
|
# Side length, in pixels, of the square label map handed to the model.
image_size = 256

# Colour palette for label drawings: remap_label maps colors[i] to the class
# id values[i], so both arrays must stay the same length and in the same order.
values = np.array([12, 2, 6, 8, 1, 10, 3, 14, 11, 4, 5, 13, 9])

# Colour triples (presumably R, G, B; remap_label compares the first component
# against channel 0 of its input), each annotated with the id it maps to.
# Note: two entries share the first component 90.
colors = np.array(
    [
        (135, 206, 235),  # -> 12
        (155, 118, 83),  # -> 2
        (176, 212, 155),  # -> 6
        (90, 188, 216),  # -> 8
        (193, 190, 186),  # -> 1
        (90, 77, 65),  # -> 10
        (86, 125, 70),  # -> 3
        (66, 105, 47),  # -> 14
        (21, 119, 190),  # -> 11
        (58, 46, 39),  # -> 4
        (77, 65, 90),  # -> 5
        (253, 218, 22),  # -> 13
        (208, 204, 204),  # -> 9
    ]
)
|
|
|
|
|
def remap_label(arr, palette=None, labels=None):
    """Convert a colour-coded label drawing into a single-channel id map.

    A pixel is assigned ``labels[i]`` when it matches ``palette[i]``; the
    first matching entry wins. Matching uses only the first (red) channel,
    except for palette entries whose red value is shared with another entry:
    those are matched on all three channels, because red alone cannot tell
    them apart. Pixels that match nothing keep their red value, and every
    result above 15 is clamped to 15.

    Args:
        arr: Array of shape (H, W, C) with C >= 3; channels beyond the third
            (e.g. alpha) are ignored. It is not modified.
        palette: Sequence of colour triples. Defaults to the module-level
            ``colors``.
        labels: Class ids, one per palette entry. Defaults to the
            module-level ``values``.

    Returns:
        A new (H, W) array with the same dtype as ``arr``.

    Raises:
        ValueError: If ``palette`` and ``labels`` differ in length.
    """
    if palette is None:
        palette = colors
    if labels is None:
        labels = values
    if len(palette) != len(labels):
        raise ValueError("palette and labels must have the same length")

    red = arr[:, :, 0]
    # arr[:, :, 0] is a view; copy so the caller's array is left untouched.
    out = red.copy()

    # Red values used by more than one palette entry (the default palette has
    # two colours with red == 90); those entries need a full RGB comparison,
    # otherwise the later one could never be matched.
    reds = [int(rgb[0]) for rgb in palette]
    shared_reds = {r for r in reds if reds.count(r) > 1}

    # Build every mask from the untouched input and skip already-assigned
    # pixels, so a written id can never be re-matched by a later entry.
    assigned = np.zeros(red.shape, dtype=bool)
    for rgb, label in zip(palette, labels):
        if int(rgb[0]) in shared_reds:
            mask = np.all(arr[:, :, :3] == np.asarray(rgb)[:3], axis=-1)
        else:
            mask = red == rgb[0]
        mask &= ~assigned
        out[mask] = label
        assigned |= mask

    out[out > 15] = 15
    return out
|
|
|
|
|
# Transform that get_artwork hands to image_loader as its ``loader``: resize
# to an image_size x image_size square, then convert the PIL image to a
# tensor (ToTensor rescales 8-bit pixel values into [0, 1]).
preprocess = transforms.Compose([
    transforms.Resize([image_size, image_size]),
    transforms.ToTensor(),
])
|
|
|
|
|
def image_loader(loader, label_inp):
    """Turn a 2-D label map into a batched float tensor.

    Args:
        loader: Callable applied to the resized PIL image that returns a
            tensor (get_artwork passes the module-level ``preprocess``).
        label_inp: Array accepted by ``PIL.Image.fromarray``, e.g. the output
            of ``remap_label``.

    Returns:
        Float tensor with a leading batch dimension of size 1 that requires
        grad; with ``preprocess`` as the loader its shape is
        (1, 3, image_size, image_size).
    """
    image = Image.fromarray(label_inp).convert("RGB")
    # The pixels are discrete class ids, so resample with NEAREST. PIL's
    # default filter for RGB images (BICUBIC since Pillow 7) would blend
    # neighbouring ids at region borders into values belonging to no class.
    image = image.resize((image_size, image_size), Image.NEAREST)
    # Multiply by 255 to undo the [0, 1] rescaling done by ToTensor in the
    # loader, restoring the original id values.
    image = loader(image).float() * 255
    image = image.clone().detach().requires_grad_(True)
    image = image.unsqueeze(0)
    return image
|
|
|
|
|
def tensor2im(image_tensor):
    """Convert the first image of a batched tensor into a PIL image.

    The transpose below implies each image is channels-first (C, H, W).
    Values are mapped linearly so that [-1, 1] spans [0, 255] (presumably the
    model's output range; TODO confirm), then clipped and truncated to uint8.
    """
    first = image_tensor[0].detach().cpu().float().numpy()
    channels_last = first.transpose(1, 2, 0)  # PIL wants (H, W, C)
    pixels = np.clip((channels_last + 1) / 2.0 * 255.0, 0, 255)
    return Image.fromarray(pixels.astype(np.uint8))
|
|
|
|
|
def get_artwork(model, data, code):
    """Render an image from a colour-coded label drawing.

    Args:
        model: Callable invoked as
            ``model(labels, mode="inference", style_codes=code)``.
        data: Anything ``np.array`` turns into an (H, W, C) colour array; it
            is copied before remapping.
        code: Forwarded unchanged as ``style_codes``.

    Returns:
        The PIL image produced by ``tensor2im`` from the model output.
    """
    id_map = remap_label(np.array(data))
    # half precision input; presumably the model runs in fp16 (TODO confirm)
    model_input = image_loader(preprocess, id_map).detach().half()
    generated = model(model_input, mode="inference", style_codes=code)
    return tensor2im(generated)
|
|