Documenting model.py
model.py CHANGED
```diff
@@ -1,11 +1,15 @@
-import tensorflow as tf
-import os
+# Deep learning
+import tensorflow as tf
+
+# Methods for loading the weights into the model
+import os
 import inspect
 
 
 _CAP = 3501 # Cap for the number of notes
 
 class Encoder_Z(tf.keras.layers.Layer):
+    # Encoder part of the VAE
 
     def __init__(self, dim_z, name="encoder", **kwargs):
         super(Encoder_Z, self).__init__(name=name, **kwargs)
@@ -31,6 +35,7 @@ class Encoder_Z(tf.keras.layers.Layer):
 
 
 class Decoder_X(tf.keras.layers.Layer):
+    # Decoder part of the VAE.
 
     def __init__(self, dim_z, name="decoder", **kwargs):
         super(Decoder_X, self).__init__(name=name, **kwargs)
@@ -64,12 +69,14 @@ kl_weight = tf.keras.backend.variable(0.125)
 
 
 class VAECost:
-
-
-
-
-
-
+    """
+    VAE cost with a schedule based on the Microsoft Research Blog's article
+    "Less pain, more gain: A simple method for VAE training with less of that KL-vanishing agony".
+
+    The KL weight increases linearly until it reaches a certain threshold,
+    stays constant there for the same number of epochs, then drops abruptly
+    back to zero, and the cycle repeats.
+    """
 
     def __init__(self, model):
         self.model = model
```
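The cyclical schedule described in the `VAECost` docstring is easy to picture as a standalone function. The sketch below is only an illustration: the actual update logic lives inside `VAECost` and is not part of this diff, so the function name `kl_weight_at` and the parameters `ramp_epochs` and `max_weight` are assumptions; only the 0.125 cap is taken from the `kl_weight` variable visible in the hunk header above.

```python
import tensorflow as tf

# Illustrative sketch of the cyclical KL schedule described in the VAECost
# docstring. The real update logic is inside VAECost and is not shown in
# this diff; kl_weight_at, ramp_epochs and max_weight are assumed names.
def kl_weight_at(epoch: int, ramp_epochs: int = 10, max_weight: float = 0.125) -> float:
    """Linear ramp up to max_weight, plateau for as many epochs, then reset to zero."""
    phase = epoch % (2 * ramp_epochs)   # position inside the current cycle
    if phase < ramp_epochs:             # ramp: grow linearly from 0 towards the cap
        return max_weight * phase / ramp_epochs
    return max_weight                   # plateau: hold the cap until the cycle restarts

# One way to apply it: refresh the backend variable at the start of each epoch.
kl_weight = tf.keras.backend.variable(0.125)
for epoch in range(4):
    tf.keras.backend.set_value(kl_weight, kl_weight_at(epoch, ramp_epochs=2))
    # epochs 0..3 with ramp_epochs=2 give weights 0.0, 0.0625, 0.125, 0.125
```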
```diff
@@ -113,6 +120,7 @@ class VAECost:
 
 
 class VAE(tf.keras.Model):
+    # Main architecture, which connects the encoder with the decoder.
 
     def __init__(self, name="variational autoencoder", **kwargs):
         super(VAE, self).__init__(name=name, **kwargs)
```
```diff
@@ -147,12 +155,22 @@ class VAE(tf.keras.Model):
                 "mean recons": mean_recons_error,
                 "kl weight": kl_weight}
 
-    def encode(self, x_input):
-
-
+    def encode(self, x_input: tf.Tensor) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
+        """
+        Take a "song map" and make a forward pass through the encoder to
+        return a sampled latent representation and the distribution's parameters.
+
+        Parameters:
+            x_input (tf.Tensor): Song map to be encoded by the VAE.
+
+        Returns:
+            tf.Tensor: A latent representation sampled from the encoding
+                       distribution (z_sample) and the parameters of that
+                       distribution (mu, sd).
+        """
 
         mu, rho = tf.split(self.encoder(x_input), num_or_size_splits=2, axis=1)
         sd = tf.math.log(1 + tf.math.exp(rho))
         z_sample = mu + sd * tf.random.normal(shape=(120,))
         return z_sample, mu, sd
 
```
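The body of `encode` is the standard reparameterization trick, worth spelling out since the diff only shows it in passing. The sketch below reruns those three lines with dummy inputs; the latent width of 120 is taken from the shapes in the code, while the zero-valued `mu` and `rho` are placeholders.

```python
import tensorflow as tf

# Standalone rerun of the reparameterization step from VAE.encode. The
# latent width 120 matches the code; mu and rho are placeholder values.
mu = tf.zeros((1, 120))    # encoder output, first half: the mean
rho = tf.zeros((1, 120))   # encoder output, second half: unconstrained pre-std

# log(1 + exp(rho)) is the softplus function (tf.math.softplus is equivalent):
# it maps any real rho to a strictly positive standard deviation.
sd = tf.math.log(1 + tf.math.exp(rho))   # softplus(0) = log 2 ≈ 0.693

# Sampling z ~ N(mu, sd) directly would block gradients through mu and sd;
# sampling eps ~ N(0, 1) and shifting/scaling keeps the pass differentiable.
eps = tf.random.normal(shape=(120,))
z_sample = mu + sd * eps                 # broadcasts to shape (1, 120)
```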
```diff
@@ -159,5 +177,15 @@ class VAE(tf.keras.Model):
-    def generate(self, z_sample=None):
-
+    def generate(self, z_sample: tf.Tensor = None) -> tf.Tensor:
+        """
+        Decode a latent representation of a song.
+
+        Parameters:
+            z_sample (tf.Tensor): Song encoding output by the encoder. If
+                                  None, the encoding is sampled from a unit
+                                  Gaussian distribution instead.
+
+        Returns:
+            tf.Tensor: Song map corresponding to the encoding.
+        """
 
         if z_sample is None:
             z_sample = tf.expand_dims(tf.random.normal(shape=(120,)), axis=0)
```
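Taken together, the two documented methods give the model a simple public API. A hypothetical round trip might look like the following; the shape of a song map is not visible in this diff, so `song_map` stands in for a correctly shaped input tensor, and weight loading is elided.

```python
vae = VAE()
# ... load weights here; the new import comment suggests os is used for this ...

z_sample, mu, sd = vae.encode(song_map)   # compress an existing song map
reconstruction = vae.generate(z_sample)   # decode it back into a song map
new_song = vae.generate()                 # or sample a fresh song from N(0, I)
```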