TomRB22 commited on
Commit
4aac37d
·
1 Parent(s): 6778487

Fixed dimensionality issues in ``encode``

Browse files
Files changed (1) hide show
  1. model.py +5 -1
model.py CHANGED
@@ -174,7 +174,11 @@ class VAE(tf.keras.Model):
174
  sd: tf.Tensor
175
  The standard deviation parameter of the distribution.
176
  """
177
-
 
 
 
 
178
  mu, rho = tf.split(self.encoder(x_input), num_or_size_splits=2, axis=1)
179
  sd = tf.math.log(1 + tf.math.exp(rho))
180
  z_sample = mu + sd * tf.random.normal(shape=(120,))
 
174
  sd: tf.Tensor
175
  The standard deviation parameter of the distribution.
176
  """
177
+ x_input = tf.expand_dims(x_input, axis=-1) # Add channel dimension
178
+
179
+ if tf.rank(x_input) == 3: # If there's no batch dimension, add it
180
+ tf.expand_dims(x_input, axis=0)
181
+
182
  mu, rho = tf.split(self.encoder(x_input), num_or_size_splits=2, axis=1)
183
  sd = tf.math.log(1 + tf.math.exp(rho))
184
  z_sample = mu + sd * tf.random.normal(shape=(120,))