File size: 3,823 Bytes
f3daba8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import numpy as np
import torch
import copy
import os
import numpy as np
from sklearn import svm

# Reduce TensorFlow's native (C++) log output for the whole process. No
# TensorFlow import is visible in this module, so this presumably targets a TF
# import made elsewhere after this module is loaded.
# NOTE(review): "3" is believed to be the most restrictive level (hides
# INFO/WARNING/ERROR messages) — confirm against the TensorFlow docs.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"


def linear_interpolate(latent_code, boundary, start_distance=-3, end_distance=3, steps=10):
    """Manipulates the given latent code with respect to a particular boundary.

    Takes a latent code and a boundary as inputs and outputs a collection of
    manipulated latent codes. For example, with `steps` = 10, an input
    `latent_code` of shape [1, latent_space_dim] and a unit-norm `boundary` of
    shape [1, latent_space_dim], the output has shape [10, latent_space_dim].
    The first output latent code is `start_distance` away from the given
    `boundary`, the last one is `end_distance` away, and the remaining codes
    are linearly interpolated in between.

    Input `latent_code` can also have shape [1, num_layers, latent_space_dim]
    to support W+ space in StyleGAN. In this case every layer is shifted by
    the same amount and the output has shape
    [steps, num_layers, latent_space_dim]. Unlike the 2-D case, the code's
    current projection onto `boundary` is NOT removed first. The distances
    are therefore offsets relative to the input code, not absolute distances
    from the boundary.

    NOTE: Distance is sign sensitive.

    Args:
        latent_code: The input latent code for manipulation (assumed to be a
            numpy array).
        boundary: The semantic boundary as reference, shape
            [1, latent_space_dim] (expected to have unit norm).
        start_distance: The distance to the boundary where the manipulation
            starts. (default: -3)
        end_distance: The distance to the boundary where the manipulation
            ends. (default: 3)
        steps: Number of steps to move the latent code from start position to
            end position. (default: 10)

    Returns:
        An array of shape [steps, latent_space_dim] for 2-D input, or
        [steps, num_layers, latent_space_dim] for 3-D input.

    Raises:
        ValueError: If `boundary` is not of shape [1, latent_space_dim]. Also
            raised if `latent_code` is not of shape [1, latent_space_dim] or
            [1, num_layers, latent_space_dim], or if its last dimension does
            not match the boundary's.
    """
    latent_ndim = len(latent_code.shape)

    # Explicit checks instead of `assert`, so validation survives `python -O`.
    if len(boundary.shape) != 2 or boundary.shape[0] != 1:
        raise ValueError(
            "Input `boundary` should be with shape [1, latent_space_dim]!\n"
            f"But {boundary.shape} is received."
        )
    if latent_ndim not in (2, 3) or latent_code.shape[0] != 1:
        raise ValueError(
            "Input `latent_code` should be with shape "
            "[1, latent_space_dim] or [1, N, latent_space_dim] for "
            "W+ space in Style GAN!\n"
            f"But {latent_code.shape} is received."
        )
    if latent_code.shape[-1] != boundary.shape[1]:
        raise ValueError(
            f"Latent dimension mismatch: `latent_code` has last dimension "
            f"{latent_code.shape[-1]} but `boundary` has {boundary.shape[1]}."
        )

    linspace = np.linspace(start_distance, end_distance, steps)
    if latent_ndim == 2:
        # Remove the code's current signed distance (its projection onto a
        # unit-norm boundary). The outputs then sit at start..end distance
        # from the boundary itself.
        linspace = linspace - latent_code.dot(boundary.T)
        linspace = linspace.reshape(-1, 1).astype(np.float32)
        return latent_code + linspace * boundary
    # W+ case: add the same offset to every layer, relative to the input code.
    linspace = linspace.reshape(-1, 1, 1).astype(np.float32)
    return latent_code + linspace * boundary.reshape(1, 1, -1)


def get_code(domain, boundaries):
    """Sample a random latent code.

    Draws a [1, 256] float32 latent from a standard normal distribution using
    ``torch.randn``, so the result follows torch's global RNG state. The
    sample is moved to CUDA when a GPU is available.

    Args:
        domain: Style domain name ("ink", "monet", "vangogh" or "water").
            Currently unused; kept for interface compatibility.
        boundaries: Per-domain boundary arrays. Currently unused; kept for
            interface compatibility.

    Returns:
        ``torch.Tensor`` of shape [1, 256].
    """
    # Sampling is domain-agnostic: `domain` and `boundaries` do not influence
    # the result. Use `modify_code` to shift the sample along a domain boundary.
    code = torch.randn(1, 256, dtype=torch.float32)
    return code.cuda() if torch.cuda.is_available() else code


def modify_code(code, boundaries, domain, range):
    """Shift a latent code along one domain's boundary.

    Args:
        code: Latent code as a torch tensor; it is converted to numpy for the
            interpolation. Presumably shaped [1, latent_dim] like the output
            of ``get_code`` — TODO confirm for other callers.
        boundaries: Boundary arrays indexable by domain index, e.g. the list
            returned by ``load_boundries``.
        domain: Domain name ("ink", "monet", "vangogh", "water"), mapped to
            indices 0-3. Any other value is used as the index unchanged.
        range: Passed to ``linear_interpolate`` as ``end_distance``. A value
            of ``0`` leaves the code untouched.

    Returns:
        ``code`` itself when ``range == 0``. Otherwise, a new ``torch.Tensor``
        holding the last interpolated code (the one at ``end_distance``),
        placed on CUDA when available.
    """
    # A zero shift is a no-op: hand back the caller's tensor object as-is.
    if range == 0:
        return code

    latent = np.array(code.cpu().detach().numpy())

    # Names map to their position in the alphabetically sorted domain list;
    # anything else (e.g. an int index) passes through unchanged.
    index = {"ink": 0, "monet": 1, "vangogh": 2, "water": 3}.get(domain, domain)

    # Interpolate three points from the default start up to `range` and keep
    # only the final one.
    path = linear_interpolate(latent, boundaries[index], end_distance=range, steps=3)
    shifted = torch.Tensor(path[-1:])
    if torch.cuda.is_available():
        shifted = shifted.cuda()
    return shifted


def load_boundries():
    """Load the per-domain boundary arrays stored next to this module.

    Each array is read from
    ``boundaries_amp_52/artwork_<domain>_boundary/boundary.npy``, relative to
    this file's directory. The arrays are returned in alphabetical domain
    order (ink, monet, vangogh, water). This matches the name-to-index
    mapping used by ``modify_code``.

    Returns:
        list of ``numpy.ndarray``, one per domain.
    """
    here = os.path.dirname(__file__)
    loaded = []
    for name in sorted(("ink", "monet", "vangogh", "water")):
        rel_path = f"boundaries_amp_52/artwork_{name}_boundary/boundary.npy"
        loaded.append(np.load(os.path.join(here, rel_path)))
    return loaded