File size: 5,207 Bytes
92ac48c
 
 
523a819
92ac48c
 
 
 
 
523a819
92ac48c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
523a819
 
92ac48c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import tensorflow as tf


_CAP = 3501 # Cap for the number of notes

class Encoder_Z(tf.keras.layers.Layer):
  """Factory for the encoder that maps a song map to latent-distribution parameters.

  ``build()`` returns a fresh ``tf.keras.Sequential`` expecting inputs of
  shape ``(batch, 3, _CAP, 1)``. The network is:

  - one 3x3, stride-2 convolution (64 filters, default padding) with ReLU;
  - a flatten;
  - dense ReLU layers of 2000 and 500 units;
  - a final linear dense layer named ``"dist_params"`` with ``2 * dim_z``
    outputs.

  NOTE: unlike ``tf.keras.layers.Layer.build(input_shape)``, this ``build``
  takes no arguments and returns a new model; it is meant to be called
  directly rather than by Keras.
  """

  def __init__(self, dim_z, name="encoder", **kwargs):
    super().__init__(name=name, **kwargs)
    self.dim_x = (3, _CAP, 1)
    self.dim_z = dim_z

  def build(self):
    """Assemble and return the encoder as a ``tf.keras.Sequential`` model."""
    keras_layers = tf.keras.layers

    # Convolutional front end followed by a flatten.
    stack = [
        keras_layers.InputLayer(input_shape=self.dim_x),
        keras_layers.Conv2D(filters=64, kernel_size=3, strides=(2, 2)),
        keras_layers.ReLU(),
        keras_layers.Flatten(),
    ]

    # Fully connected trunk: 2000 -> 500 units, each followed by a ReLU.
    for units in (2000, 500):
      stack.extend([keras_layers.Dense(units), keras_layers.ReLU()])

    # Linear head emitting 2 * dim_z distribution parameters.
    stack.append(
        keras_layers.Dense(self.dim_z * 2, activation=None, name="dist_params"))

    return tf.keras.Sequential(stack)


class Decoder_X(tf.keras.layers.Layer):
  """Factory for the decoder that maps a latent vector back to a song map.

  ``build()`` returns a fresh ``tf.keras.Sequential`` taking inputs of shape
  ``(batch, dim_z)``. The network is:

  - dense ReLU layers of 500 and 2000 units;
  - a linear dense projection reshaped to a ``(1, (_CAP - 1) // 2, 32)``
    feature map;
  - two transposed convolutions that upsample it.

  For an odd ``_CAP`` the output shape is ``(batch, 3, _CAP, 1)``. The final
  layer is linear (no activation).

  NOTE: unlike ``tf.keras.layers.Layer.build(input_shape)``, this ``build``
  takes no arguments and returns a new model; it is meant to be called
  directly rather than by Keras.
  """

  def __init__(self, dim_z, name="decoder", **kwargs):
    super(Decoder_X, self).__init__(name=name, **kwargs)
    self.dim_z = dim_z

  def build(self):
    """Assemble and return the decoder as a ``tf.keras.Sequential`` model."""
    # Width of the feature map fed to the strided transposed convolution.
    # One integer value is used for both the Dense unit count and the Reshape
    # target so they always agree. The previous `(_CAP - 1) / 2 * 32` produced
    # a float unit count (which newer Keras may reject). For an even _CAP it
    # also gave a different element count than the int-truncated Reshape shape.
    half_width = (_CAP - 1) // 2
    channels = 32

    layers = [tf.keras.layers.InputLayer(input_shape=(self.dim_z,))]

    layers.append(tf.keras.layers.Dense(500))
    layers.append(tf.keras.layers.ReLU())

    layers.append(tf.keras.layers.Dense(2000))
    layers.append(tf.keras.layers.ReLU())

    layers.append(tf.keras.layers.Dense(half_width * channels, activation=None))
    layers.append(tf.keras.layers.Reshape((1, half_width, channels)))

    # Kernel 3 / stride 2 / 'valid' padding: height 1 -> 3 and width
    # w -> 2 * (w - 1) + 3, which equals _CAP when _CAP is odd.
    layers.append(tf.keras.layers.Conv2DTranspose(
        filters=64, kernel_size=3, strides=2, padding='valid'))
    layers.append(tf.keras.layers.ReLU())

    # Stride 1 with 'same' padding keeps the spatial size; reduce to 1 channel.
    layers.append(tf.keras.layers.Conv2DTranspose(
        filters=1, kernel_size=3, strides=1, padding='same'))

    return tf.keras.Sequential(layers)

# Global weight applied to the KL term of the ELBO, initialised to 0.125.
# It is read inside VAECost.__call__ and reported as "kl weight" by
# VAE.train_step.
# NOTE(review): nothing visible here updates it. Presumably a training callback
# implements the cyclical KL schedule described on VAECost; confirm.
kl_weight = tf.keras.backend.variable(0.125)



class VAECost:
  """Negative-ELBO cost for a VAE, weighting the KL term by ``kl_weight``.

  Intended to follow the KL schedule from the Microsoft Research Blog
  article "Less pain, more gain: A simple method for VAE training with less
  of that KL-vanishing agony":

  - the KL weight grows linearly until it reaches a threshold;
  - it is held there for the same number of epochs;
  - it then drops abruptly to zero, and the cycle repeats.

  NOTE(review): this class only *reads* the module-level ``kl_weight``. The
  schedule state set in ``__init__`` (``kl_weight_increasing``, ``epoch``)
  is not used here, so the schedule is presumably driven elsewhere; confirm.
  """

  def __init__(self, model):
    # Model whose ``encode`` method and ``decoder`` attribute produce the cost.
    self.model = model
    # Schedule bookkeeping (not read inside this class).
    self.kl_weight_increasing = True
    self.epoch = 1

  # Keras losses normally look like loss(y_true, y_pred). Here the prediction
  # is produced inside the call by running ``self.model`` on the input.
  @tf.function()
  def __call__(self, x_true):
    """Return ``(-elbo, mean_kl, mean_recons)`` for a batch of song maps.

    Returns:
      - ``-elbo``: the batch mean of ``kl_weight * KL + MSE``.
      - ``mean_kl``: the unweighted batch mean of the per-example KL term.
      - ``mean_recons``: the unweighted batch mean of the reconstruction MSE.
    """
    inputs = tf.cast(x_true, tf.float32)

    # Latent sample plus the posterior mean / standard deviation.
    latent, mean, std_dev = self.model.encode(inputs)
    # Ideally the reconstruction closely matches the input song map.
    reconstruction = self.model.decoder(latent)

    # Per-example mean squared error over axes 1-3 (assumes rank-4 inputs).
    squared_error = (inputs - reconstruction) ** 2
    per_example_mse = tf.cast(
        tf.reduce_mean(squared_error, axis=[1, 2, 3]), tf.float32)

    # Closed-form KL( N(mean, std_dev^2) || N(0, I) ), one value per example.
    log_variance = tf.math.log(tf.math.square(std_dev))
    per_example_kl = -0.5 * tf.math.reduce_sum(
        1 + log_variance - tf.math.square(mean) - tf.math.square(std_dev),
        axis=1)

    weighted_objective = -kl_weight * per_example_kl - per_example_mse
    elbo = tf.reduce_mean(weighted_objective)
    return (-elbo,
            tf.reduce_mean(per_example_kl),
            tf.reduce_mean(per_example_mse))


class VAE(tf.keras.Model):
  """Variational autoencoder built from ``Encoder_Z`` and ``Decoder_X``.

  ``train_step`` minimises the negative ELBO computed by ``VAECost``.

  Args:
    dim_z: dimensionality of the latent space.
    seed: stored as ``self.seed``; not used by the code in this class.
    analytic_kl: stored as ``self.analytic_kl``; not used by the code in
      this class.
    name: Keras model name.
  """

  def __init__(self, dim_z, seed=2000, analytic_kl=True, name="autoencoder", **kwargs):
    super(VAE, self).__init__(name=name, **kwargs)
    # Input shape of one song map. This previously referenced the undefined
    # name `CAP`, which raised NameError whenever a VAE was constructed.
    self.dim_x = (3, _CAP, 1)
    self.dim_z = dim_z
    self.seed = seed
    self.analytic_kl = analytic_kl
    self.encoder = Encoder_Z(dim_z=self.dim_z).build()
    self.decoder = Decoder_X(dim_z=self.dim_z).build()
    self.cost_func = VAECost(self)

  @tf.function()
  def train_step(self, data):
    """Run one gradient-descent step on the negative ELBO and return metrics.

    NOTE(review): ``data`` is passed straight to the cost as the input song
    map. This assumes the dataset yields bare inputs, not (x, y) tuples.
    """
    with tf.GradientTape() as tape:
      neg_elbo, mean_kl_divergence, mean_recons_error = self.cost_func(data)

    gradients = tape.gradient(neg_elbo, self.trainable_variables)
    self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

    return {"abs ELBO": neg_elbo, "mean KL": mean_kl_divergence,
            "mean recons": mean_recons_error,
            "kl weight": kl_weight}

  def encode(self, x_input):
    """Encode ``x_input`` and draw a reparameterised latent sample.

    The encoder output is split in half along axis 1 into ``mu`` and a
    pre-activation scale ``rho``.

    Returns:
      ``(z_sample, mu, sd)``, all with the same shape as ``mu``.
    """
    mu, rho = tf.split(self.encoder(x_input), num_or_size_splits=2, axis=1)
    # Softplus maps rho to a positive standard deviation. tf.math.softplus is
    # the numerically stable form of log(1 + exp(rho)); the explicit form
    # overflows to inf for large rho in float32.
    sd = tf.math.softplus(rho)
    # Reparameterisation trick with independent noise per example and latent
    # dimension. The previous shape=(dim_z,) drew a single noise vector that
    # was broadcast over the whole batch, correlating every sample.
    epsilon = tf.random.normal(shape=tf.shape(mu), dtype=mu.dtype)
    z_sample = mu + sd * epsilon
    return z_sample, mu, sd

  def generate(self, z_sample=None):
    """Decode ``z_sample``, or a single latent drawn from N(0, I) if omitted."""
    # Identity check: `== None` on a tensor/array is an element-wise or
    # version-dependent comparison, not a test for "argument omitted".
    if z_sample is None:
      z_sample = tf.expand_dims(tf.random.normal(shape=(self.dim_z,)), axis=0)
    return self.decoder(z_sample)