File size: 1,937 Bytes
f3daba8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba32963
30e21c7
f3daba8
 
 
ba32963
f3daba8
ba32963
 
f3daba8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
"""
Copyright (C) 2019 NVIDIA Corporation.  All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""
import numpy as np
from PIL import Image
from torchvision import transforms

# Side length, in pixels, of the square images this module works with.
image_size = 256

# One row per palette entry: the (R, G, B) colour and the label id it is
# remapped to. `colors[i]` and `values[i]` are derived from the same row, so
# the two arrays cannot drift out of step.
_color_label_pairs = (
    ((135, 206, 235), 12),
    ((155, 118, 83), 2),
    ((176, 212, 155), 6),
    ((90, 188, 216), 8),
    ((193, 190, 186), 1),
    ((90, 77, 65), 10),
    ((86, 125, 70), 3),
    ((66, 105, 47), 14),
    ((21, 119, 190), 11),
    ((58, 46, 39), 4),
    ((77, 65, 90), 5),
    ((253, 218, 22), 13),
    ((208, 204, 204), 9),
)

# to label: 1-D array of label ids, in palette order
values = np.array([label for _, label in _color_label_pairs])

# from color: (N, 3) array of RGB triples, in palette order
colors = np.array([rgb for rgb, _ in _color_label_pairs])


def remap_label(arr):
    """Convert a colour image array into a 2-D map of label ids.

    Only the third channel (index 2) of each pixel is inspected, so two
    palette colours that share that channel value cannot be told apart.

    Args:
        arr: Array indexed as ``arr[row, col, channel]`` with at least three
            channels. It is not modified.

    Returns:
        A new 2-D array with the same dtype as ``arr``. Pixels whose third
        channel equals that of an entry in ``colors`` receive the matching
        entry of ``values``; every remaining value above 15 becomes 15;
        remaining values of 15 or less are left as they are.
    """
    # compare only last color channel to speed up.
    # Take a copy: plain slicing returns a view, and writing into a view
    # would overwrite the caller's data (or fail on a read-only array).
    arr_b = arr[:, :, 2].copy()

    # remap color to label
    for color, label in zip(colors, values):
        arr_b[arr_b == color[2]] = label
    # others to 15
    arr_b[arr_b > 15] = 15
    return arr_b


# Transform pipeline applied to a PIL image before it reaches the model:
# first resize to image_size x image_size, then convert to a tensor.
_preprocess_steps = [
    transforms.Resize([image_size, image_size]),
    transforms.ToTensor(),
]
preprocess = transforms.Compose(_preprocess_steps)


def image_loader(loader, label_inp):
    """Turn a label map into a single-item batch tensor.

    Args:
        loader: Callable applied to the resized RGB PIL image. Its result
            must support ``.float()`` (e.g. the module-level ``preprocess``).
        label_inp: Array accepted by ``Image.fromarray``.

    Returns:
        The loader output cast to float and multiplied by 255, detached from
        any prior graph, with gradient tracking enabled and a leading batch
        axis of size 1 added.
    """
    image = Image.fromarray(label_inp).convert("RGB")
    # Label ids are categorical, so resize with nearest-neighbour. An
    # interpolating filter (Pillow's default on recent versions) would blend
    # neighbouring ids into in-between values that are not valid labels.
    image = image.resize((image_size, image_size), resample=Image.NEAREST)
    # NOTE(review): the * 255 presumably undoes a [0, 1] scaling done by the
    # loader (as torchvision's ToTensor does) - confirm for other loaders.
    image = loader(image).float() * 255
    image = image.clone().detach().requires_grad_(True)
    image = image.unsqueeze(0)
    return image


def tensor2im(image_tensor):
    """Convert the first item of a tensor batch into a PIL image.

    The item is moved to the CPU as float, its axes are reordered from
    (0, 1, 2) to (1, 2, 0), and its values are mapped linearly so that -1
    becomes 0 and 1 becomes 255. Anything outside 0..255 after that mapping
    is clipped before the cast to uint8.

    Args:
        image_tensor: Tensor whose first element is 3-D.
            NOTE(review): presumably (batch, channel, height, width) with
            values in [-1, 1] - confirm against the model's output.

    Returns:
        A ``PIL.Image.Image`` built from the uint8 array.
    """
    first_item = image_tensor[0].detach().cpu().float().numpy()
    # Move the leading axis to the end, the layout Image.fromarray expects.
    channels_last = np.transpose(first_item, (1, 2, 0))
    scaled = (channels_last + 1) / 2.0 * 255.0
    pixels = np.clip(scaled, 0, 255).astype(np.uint8)
    return Image.fromarray(pixels)


def get_artwork(model, data, code):
    """Run the model on a colour-coded drawing and return the result image.

    Args:
        model: Callable invoked as
            ``model(tensor, mode="inference", style_codes=code)``.
        data: Image-like input convertible with ``np.array``; it is passed to
            ``remap_label`` to obtain label ids.
        code: Forwarded unchanged to the model as ``style_codes``.

    Returns:
        The model output converted to a PIL image by ``tensor2im``.
    """
    # Colour drawing -> label map -> batched tensor.
    label_map = remap_label(np.array(data))
    batch = image_loader(preprocess, label_map)
    # Drop gradient tracking and cast to float16 before the forward pass.
    # NOTE(review): presumably the model runs in half precision - confirm.
    batch = batch.detach().half()

    generated = model(batch, mode="inference", style_codes=code)
    return tensor2im(generated)