TomRB22 committed on
Commit
63f1a8f
·
1 Parent(s): e878ec7

Documenting model.py

Browse files
Files changed (1) hide show
  1. model.py +41 -13
model.py CHANGED
@@ -1,11 +1,15 @@
1
- import tensorflow as tf
2
- import os
 
 
 
3
  import inspect
4
 
5
 
6
  _CAP = 3501 # Cap for the number of notes
7
 
8
  class Encoder_Z(tf.keras.layers.Layer):
 
9
 
10
  def __init__(self, dim_z, name="encoder", **kwargs):
11
  super(Encoder_Z, self).__init__(name=name, **kwargs)
@@ -31,6 +35,7 @@ class Encoder_Z(tf.keras.layers.Layer):
31
 
32
 
33
  class Decoder_X(tf.keras.layers.Layer):
 
34
 
35
  def __init__(self, dim_z, name="decoder", **kwargs):
36
  super(Decoder_X, self).__init__(name=name, **kwargs)
@@ -64,12 +69,14 @@ kl_weight = tf.keras.backend.variable(0.125)
64
 
65
 
66
  class VAECost:
67
- # VAE cost with a schedule based on the Microsoft Research Blog's article
68
- # "Less pain, more gain: A simple method for VAE training with less of that KL-vanishing agony"
69
- #
70
- # The KL weight increases linearly, until it meets a certain threshold and keeps constant
71
- # for the same number of epochs. After that, it decreases abruptly to zero again, and the
72
- # cycle repeats.
 
 
73
 
74
  def __init__(self, model):
75
  self.model = model
@@ -113,6 +120,7 @@ class VAECost:
113
 
114
 
115
  class VAE(tf.keras.Model):
 
116
 
117
  def __init__(self, name="variational autoencoder", **kwargs):
118
  super(VAE, self).__init__(name=name, **kwargs)
@@ -147,17 +155,37 @@ class VAE(tf.keras.Model):
147
  "mean recons": mean_recons_error,
148
  "kl weight": kl_weight}
149
 
150
- def encode(self, x_input):
151
- # Get a "song map" and make a forward pass through the encoder, in order
152
- # to return the latent representation and the distribution's parameters
 
 
 
 
 
 
 
 
 
 
153
 
154
  mu, rho = tf.split(self.encoder(x_input), num_or_size_splits=2, axis=1)
155
  sd = tf.math.log(1 + tf.math.exp(rho))
156
  z_sample = mu + sd * tf.random.normal(shape=(120,))
157
  return z_sample, mu, sd
158
 
159
- def generate(self, z_sample=None):
160
- # Decode a latent representation of a song, which is provided or sampled
 
 
 
 
 
 
 
 
 
 
161
 
162
  if z_sample == None:
163
  z_sample = tf.expand_dims(tf.random.normal(shape=(120,)), axis=0)
 
1
+ # Deep learning
2
+ import tensorflow as tf
3
+
4
+ # Methods for loading the weights into the model
5
+ import os
6
  import inspect
7
 
8
 
9
  _CAP = 3501 # Cap for the number of notes
10
 
11
  class Encoder_Z(tf.keras.layers.Layer):
12
+ # Encoder part of the VAE
13
 
14
  def __init__(self, dim_z, name="encoder", **kwargs):
15
  super(Encoder_Z, self).__init__(name=name, **kwargs)
 
35
 
36
 
37
  class Decoder_X(tf.keras.layers.Layer):
38
+ # Decoder part of the VAE.
39
 
40
  def __init__(self, dim_z, name="decoder", **kwargs):
41
  super(Decoder_X, self).__init__(name=name, **kwargs)
 
69
 
70
 
71
  class VAECost:
72
+ """
73
+ VAE cost with a schedule based on the Microsoft Research Blog's article
74
+ "Less pain, more gain: A simple method for VAE training with less of that KL-vanishing agony"
75
+
76
+ The KL weight increases linearly, until it meets a certain threshold and keeps constant
77
+ for the same number of epochs. After that, it decreases abruptly to zero again, and the
78
+ cycle repeats.
79
+ """
80
 
81
  def __init__(self, model):
82
  self.model = model
 
120
 
121
 
122
  class VAE(tf.keras.Model):
123
+ # Main architecture, which connects the encoder with the decoder.
124
 
125
  def __init__(self, name="variational autoencoder", **kwargs):
126
  super(VAE, self).__init__(name=name, **kwargs)
 
155
  "mean recons": mean_recons_error,
156
  "kl weight": kl_weight}
157
 
158
+ def encode(self, x_input: tf.Tensor) -> tuple[tf.Tensor]:
159
+ """
160
+ Get a "song map" and make a forward pass through the encoder, in order
161
+ to return the latent representation and the distribution's parameters.
162
+
163
+ Parameters:
164
+ x_input (tf.Tensor): Song map to be encoded by the VAE.
165
+
166
+ Returns:
167
+ tf.Tensor: The parameters of the distribution which encode the song
168
+ (mu, sd) and a sampled latent representation from this
169
+ distribution (z_sample).
170
+ """
171
 
172
  mu, rho = tf.split(self.encoder(x_input), num_or_size_splits=2, axis=1)
173
  sd = tf.math.log(1 + tf.math.exp(rho))
174
  z_sample = mu + sd * tf.random.normal(shape=(120,))
175
  return z_sample, mu, sd
176
 
177
+ def generate(self, z_sample: tf.Tensor=None) -> tf.Tensor:
178
+ """
179
+ Decode a latent representation of a song.
180
+
181
+ Parameters:
182
+ z_sample (tf.Tensor): Song encoding output by the encoder. If
183
+ None, this sampling is done over a
184
+ unit Gaussian distribution.
185
+
186
+ Returns:
187
+ tf.Tensor: Song map corresponding to the encoding.
188
+ """
189
 
190
  if z_sample == None:
191
  z_sample = tf.expand_dims(tf.random.normal(shape=(120,)), axis=0)