|
|
|
import tensorflow as tf |
|
|
|
|
|
import os |
|
import inspect |
|
|
|
|
|
_CAP = 3501 |
|
|
|
class Encoder_Z(tf.keras.layers.Layer):
    """Factory for the convolutional encoder network.

    The object mainly carries the input shape and the latent size; the actual
    network is produced by :meth:`build`, which returns a
    ``tf.keras.Sequential`` model.
    """

    def __init__(self, dim_z, name="encoder", **kwargs):
        """Store the latent size ``dim_z`` and the fixed input shape."""
        super().__init__(name=name, **kwargs)
        self.dim_z = dim_z
        self.dim_x = (3, _CAP, 1)

    def build(self):
        """Assemble and return the encoder as a ``tf.keras.Sequential``.

        Returns
        -------
        tf.keras.Sequential
            A strided convolution followed by two ReLU-activated dense layers
            and a final linear layer with ``2 * dim_z`` outputs (named
            "dist_params").
        """
        keras_layers = tf.keras.layers

        # The layers are instantiated in the same order as they are stacked so
        # that Keras' automatic layer naming is unaffected.
        stack = [
            keras_layers.InputLayer(input_shape=self.dim_x),
            # Convolutional feature extractor.
            keras_layers.Conv2D(filters=64, kernel_size=3, strides=(2, 2)),
            keras_layers.ReLU(),
            keras_layers.Flatten(),
            # Fully connected trunk.
            keras_layers.Dense(2000),
            keras_layers.ReLU(),
            keras_layers.Dense(500),
            keras_layers.ReLU(),
            # Linear head: two values per latent dimension.
            keras_layers.Dense(self.dim_z * 2, activation=None, name="dist_params"),
        ]
        return tf.keras.Sequential(stack)
|
|
|
|
|
class Decoder_X(tf.keras.layers.Layer):
    """Factory for the decoder network that maps latent codes to outputs.

    The object only carries the latent size; the actual network is produced by
    :meth:`build`, which returns a ``tf.keras.Sequential`` model.
    """

    def __init__(self, dim_z, name="decoder", **kwargs):
        """Store the latent size ``dim_z`` expected as decoder input."""
        super().__init__(name=name, **kwargs)
        self.dim_z = dim_z

    def build(self):
        """Assemble and return the decoder as a ``tf.keras.Sequential``.

        Returns
        -------
        tf.keras.Sequential
            Two ReLU-activated dense layers, a linear projection reshaped to
            ``(1, (_CAP - 1) // 2, 32)``, and two transposed convolutions
            ending in a single output channel.
        """
        # Integer arithmetic on purpose: ``Dense`` expects an integer number
        # of units, and true division would hand it a float (56000.0).
        half_width = (_CAP - 1) // 2
        channels = 32

        layers = [tf.keras.layers.InputLayer(input_shape=(self.dim_z,))]

        layers.append(tf.keras.layers.Dense(500))
        layers.append(tf.keras.layers.ReLU())

        layers.append(tf.keras.layers.Dense(2000))
        layers.append(tf.keras.layers.ReLU())

        # Linear projection that is folded into a (1, half_width, channels)
        # feature map for the transposed convolutions.
        layers.append(tf.keras.layers.Dense(half_width * channels, activation=None))
        layers.append(tf.keras.layers.Reshape((1, half_width, channels)))

        # Stride-2, kernel-3, 'valid' upsampling: 1 -> 3 rows and
        # half_width -> 2 * half_width + 1 (= _CAP) columns.
        layers.append(tf.keras.layers.Conv2DTranspose(
            filters=64, kernel_size=3, strides=2, padding='valid'))
        layers.append(tf.keras.layers.ReLU())

        # Collapse to one output channel without changing the spatial size.
        layers.append(tf.keras.layers.Conv2DTranspose(
            filters=1, kernel_size=3, strides=1, padding='same'))

        return tf.keras.Sequential(layers)
|
|
|
# Weight applied to the KL-divergence term of the loss. It is a backend
# variable (initial value 0.125) rather than a Python float so its value can
# be changed without rebuilding the traced graphs that read it.
# NOTE(review): nothing in this file updates it — presumably an external
# schedule/callback does; confirm.
kl_weight = tf.keras.backend.variable(0.125)
|
|
|
|
|
|
|
class VAECost:
    """
    VAE cost with a schedule based on the Microsoft Research Blog's article
    "Less pain, more gain: A simple method for VAE training with less of that KL-vanishing agony"

    The KL weight increases linearly, until it meets a certain threshold and keeps constant
    for the same number of epochs. After that, it decreases abruptly to zero again, and the
    cycle repeats.

    NOTE(review): this class only *reads* the module-level ``kl_weight``
    variable; the schedule itself is not implemented here.
    ``kl_weight_increasing`` and ``epoch`` are initialised but never updated
    in this class — presumably external code drives them; confirm.
    """

    def __init__(self, model):
        """Keep a reference to the VAE whose ``encode``/``decoder`` are used."""
        self.model = model
        self.kl_weight_increasing = True
        self.epoch = 1

    @tf.function()
    def __call__(self, x_true):
        """
        Compute the negative ELBO for a batch.

        Parameters
        ----------
        x_true : tensor-like
            Input passed to ``model.encode``; cast to ``float32`` first.

        Returns
        -------
        neg_elbo : tf.Tensor
            Batch mean of ``kl_weight * KL + reconstruction error``.
        mean_kl_divergence : tf.Tensor
            Batch mean of the (unweighted) KL divergence.
        mean_recons_error : tf.Tensor
            Batch mean of the per-sample mean squared reconstruction error.
        """
        x_true = tf.cast(x_true, tf.float32)

        # Encode to a latent sample plus the distribution parameters, then
        # reconstruct from the sample.
        z_sample, mu, sd = self.model.encode(x_true)
        x_recons = self.model.decoder(z_sample)

        # ``encode`` appends a trailing axis (and a leading batch axis for a
        # single example) to its own copy of the input, so the reconstruction
        # has more axes than ``x_true``. Mirror those expansions here;
        # otherwise the subtraction cannot broadcast and the reduction over
        # axes [1, 2, 3] is invalid. Ranks are static, so these are plain
        # Python branches even inside ``tf.function``.
        x_target = x_true
        recons_rank = x_recons.shape.rank
        if recons_rank is not None and x_target.shape.rank is not None:
            if x_target.shape.rank < recons_rank:
                x_target = tf.expand_dims(x_target, axis=-1)
            if x_target.shape.rank < recons_rank:
                x_target = tf.expand_dims(x_target, axis=0)

        # Per-sample mean squared error over all non-batch axes.
        recons_error = tf.cast(
            tf.reduce_mean((x_target - x_recons) ** 2, axis=[1, 2, 3]),
            tf.float32)

        # KL(N(mu, sd^2) || N(0, 1)) summed over the latent axis.
        # log(sd^2) is written as 2*log(sd) so that squaring a tiny ``sd``
        # cannot underflow to zero before the logarithm is taken.
        kl_divergence = -0.5 * tf.math.reduce_sum(
            1 + 2.0 * tf.math.log(sd) - tf.math.square(mu) - tf.math.square(sd),
            axis=1)

        elbo = tf.reduce_mean(-kl_weight * kl_divergence - recons_error)
        mean_kl_divergence = tf.reduce_mean(kl_divergence)
        mean_recons_error = tf.reduce_mean(recons_error)

        return -elbo, mean_kl_divergence, mean_recons_error
|
|
|
|
|
class VAE(tf.keras.Model):
    """
    Variational autoencoder built from ``Encoder_Z`` and ``Decoder_X``.

    On construction the model builds both sub-networks, creates its cost
    function and loads weights from the ``weights`` directory located next to
    this source file.
    """

    # Latent dimensionality shared by the encoder, the decoder and the prior
    # sampling in ``decode``.
    _DIM_Z = 120

    def __init__(self, name="variational autoencoder", **kwargs):
        """Build the sub-networks and load the weights stored beside this file."""
        super().__init__(name=name, **kwargs)
        self.dim_x = (3, _CAP, 1)
        self.encoder = Encoder_Z(dim_z=self._DIM_Z).build()
        self.decoder = Decoder_X(dim_z=self._DIM_Z).build()
        self.cost_func = VAECost(self)

        # Resolve the weights location relative to this file rather than the
        # current working directory.
        script_path = inspect.getfile(inspect.currentframe())
        script_dir = os.path.dirname(os.path.abspath(script_path))

        # The trailing separator is kept: the path is used as given for the
        # weights prefix.
        weights_dir = os.path.join(script_dir, 'weights') + os.sep

        self.load_weights(weights_dir)

    @tf.function()
    def train_step(self, data):
        """
        Run one optimisation step on ``data``.

        Returns
        -------
        dict
            The negative ELBO, the mean KL divergence, the mean reconstruction
            error and the current KL weight.
        """
        with tf.GradientTape() as tape:
            neg_elbo, mean_kl_divergence, mean_recons_error = self.cost_func(data)

        gradients = tape.gradient(neg_elbo, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        return {"abs ELBO": neg_elbo, "mean KL": mean_kl_divergence,
                "mean recons": mean_recons_error,
                "kl weight": kl_weight}

    def encode(self, x_input: tf.Tensor) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
        """
        Make a forward pass through the encoder for a given song map, in order
        to return the latent representation and the distribution's parameters.

        Parameters
        ----------
        x_input : tf.Tensor
            Song map to be encoded by the VAE. A trailing axis is appended,
            and a leading batch axis is added when the result has rank 3
            (i.e. a single, unbatched song map).

        Returns
        -------
        z_sample: tf.Tensor
            A sampled latent representation from the distribution which encodes the song.
        mu: tf.Tensor
            The mean parameter of the distribution.
        sd: tf.Tensor
            The standard deviation parameter of the distribution.
        """
        x_input = tf.expand_dims(x_input, axis=-1)

        # Use the static rank: a tensor-valued condition (``tf.rank(...) == 3``)
        # would be converted into ``tf.cond`` when this method is traced from
        # a ``tf.function``, and its two branches would yield different ranks.
        if x_input.shape.rank == 3:
            x_input = tf.expand_dims(x_input, axis=0)

        mu, rho = tf.split(self.encoder(x_input), num_or_size_splits=2, axis=1)

        # softplus(rho) == log(1 + exp(rho)), evaluated without overflow for
        # large ``rho`` and without collapsing to exactly 0 for very negative
        # ``rho``.
        sd = tf.math.softplus(rho)

        # Draw independent noise for every example; a fixed (dim_z,) shape
        # would broadcast one noise vector over the whole batch.
        noise = tf.random.normal(shape=tf.shape(mu), dtype=mu.dtype)
        z_sample = mu + sd * noise
        return z_sample, mu, sd

    def decode(self, z_sample: tf.Tensor = None) -> tf.Tensor:
        """
        Decode a latent representation of a song.

        Parameters
        ----------
        z_sample : tf.Tensor
            Song encoding output by the encoder.
            Default ``None``, for which the sampling is done over a unit Gaussian distribution.

        Returns
        -------
        song_map: tf.Tensor
            Song map corresponding to the encoding.
        """
        # Identity check: ``== None`` on a tensor/array is an element-wise
        # comparison and is not a reliable "argument missing" test.
        if z_sample is None:
            z_sample = tf.random.normal(shape=(1, self._DIM_Z))

        song_map = self.decoder(z_sample)
        return song_map
|
|