Fixed dimensionality issues in ``encode``
Browse files
model.py
CHANGED
@@ -174,7 +174,11 @@ class VAE(tf.keras.Model):
|
|
174 |
sd: tf.Tensor
|
175 |
The standard deviation parameter of the distribution.
|
176 |
"""
|
177 |
-
|
|
|
|
|
|
|
|
|
178 |
mu, rho = tf.split(self.encoder(x_input), num_or_size_splits=2, axis=1)
|
179 |
sd = tf.math.log(1 + tf.math.exp(rho))
|
180 |
z_sample = mu + sd * tf.random.normal(shape=(120,))
|
|
|
174 |
sd: tf.Tensor
|
175 |
The standard deviation parameter of the distribution.
|
176 |
"""
|
177 |
+
x_input = tf.expand_dims(x_input, axis=-1) # Add channel dimension
|
178 |
+
|
179 |
+
if tf.rank(x_input) == 3: # If there's no batch dimension, add it
|
180 |
+
tf.expand_dims(x_input, axis=0)
|
181 |
+
|
182 |
mu, rho = tf.split(self.encoder(x_input), num_or_size_splits=2, axis=1)
|
183 |
sd = tf.math.log(1 + tf.math.exp(rho))
|
184 |
z_sample = mu + sd * tf.random.normal(shape=(120,))
|