Spaces:
Sleeping
Sleeping
import numpy as np | |
import tensorflow as tf | |
import tensorflow_probability as tfp | |
tfd = tfp.distributions | |
from .networks import dense_nn, cond_dense_nn | |
class CondVAE(object):
    """Conditional VAE encoder/decoder built from diagonal Gaussians.

    Hyperparameters are read from ``hps`` by key:
      - ``'hid_dimensions'``: size of the latent code.
      - ``'enc_dense_hids'`` / ``'dec_dense_hids'``: hidden-layer spec handed
        unchanged to the ``.networks`` builders for the encoder / decoder.
      - ``'dimension'``: size of the reconstructed data vector.

    Both networks are built under ``tf.compat.v1.variable_scope(name)`` with
    ``AUTO_REUSE``. Repeated calls therefore share weights, assuming the
    ``.networks`` builders create variables via ``tf.compat.v1.get_variable``
    (TODO confirm).
    """

    def __init__(self, hps, name="cvae"):
        self.hps = hps
        self.name = name

    def enc(self, x, cond=None):
        """Encode ``x`` (optionally conditioned on ``cond``) into a latent sample.

        Args:
            x: rank-2 input tensor ``[B, C]``.
            cond: optional conditioning tensor, presumably ``[B, C_cond]``
                (TODO confirm against ``cond_dense_nn``). When ``None``, the
                unconditional ``dense_nn`` builder is used instead.

        Returns:
            A tuple ``(neg_kl, posterior_sample)``:
              - ``neg_kl``: ``-KL(q(z|x) || N(0, I))``, one value per batch row.
                The sign is NEGATIVE, and callers rely on that.
              - ``posterior_sample``: reparameterized sample ``m + eps * s`` of
                shape ``[B, hid_dimensions]``.
        """
        hid = self.hps['hid_dimensions']
        # Take the batch size from the input before any rebinding happens.
        batch_size = tf.shape(input=x)[0]
        with tf.compat.v1.variable_scope(self.name, reuse=tf.compat.v1.AUTO_REUSE):
            # The trailing positional ``False`` is passed through unchanged; its
            # meaning is defined by the ``.networks`` builders.
            if cond is None:
                out = dense_nn(x, self.hps['enc_dense_hids'], 2 * hid, False, "enc")
            else:
                out = cond_dense_nn(x, cond, self.hps['enc_dense_hids'], 2 * hid, False, "enc")
            # First half of the output is the mean. The second half goes through
            # softplus so it is a strictly positive standard deviation.
            m, s = out[:, :hid], tf.nn.softplus(out[:, hid:])
            # Standard-normal prior, created in the network's dtype so float64
            # models work as well as float32 ones.
            prior_dist = tfd.MultivariateNormalDiag(
                loc=tf.zeros(hid, dtype=m.dtype),
                scale_diag=tf.ones(hid, dtype=m.dtype))
            posterior_dist = tfd.MultivariateNormalDiag(loc=m, scale_diag=s)
            neg_kl = -tfd.kl_divergence(posterior_dist, prior_dist)
            # Reparameterization trick: z = m + eps * s with eps ~ N(0, I), so
            # gradients flow through m and s.
            eps = prior_dist.sample(batch_size)
            posterior_sample = m + eps * s
        return neg_kl, posterior_sample

    def dec(self, x, cond=None):
        """Decode latent codes ``x`` into a distribution over data vectors.

        Args:
            x: rank-2 latent tensor ``[B, C]``.
            cond: optional conditioning tensor (shape as expected by
                ``cond_dense_nn``, TODO confirm). When ``None``, the
                unconditional ``dense_nn`` builder is used instead.

        Returns:
            A ``tfd.MultivariateNormalDiag`` whose ``loc`` and softplus
            ``scale_diag`` each have shape ``[B, dimension]``.
        """
        dim = self.hps['dimension']
        with tf.compat.v1.variable_scope(self.name, reuse=tf.compat.v1.AUTO_REUSE):
            if cond is None:
                out = dense_nn(x, self.hps['dec_dense_hids'], 2 * dim, False, "dec")
            else:
                out = cond_dense_nn(x, cond, self.hps['dec_dense_hids'], 2 * dim, False, "dec")
            # Split the output into mean and softplus-positive scale, as in enc.
            m, s = out[:, :dim], tf.nn.softplus(out[:, dim:])
            sample_dist = tfd.MultivariateNormalDiag(loc=m, scale_diag=s)
        return sample_dist