File size: 1,999 Bytes
9016314
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

from .networks import dense_nn, cond_dense_nn

class CondVAE(object):
    """Conditional VAE with diagonal-Gaussian encoder and decoder.

    Both networks are built with ``dense_nn``, or with ``cond_dense_nn`` when a
    conditioning tensor is supplied. They are created inside
    ``tf.compat.v1.variable_scope(name, reuse=AUTO_REUSE)``, so repeated calls
    to ``enc``/``dec`` on the same instance name share weights.

    Keys read from ``hps``:
        hid_dimensions: latent size H.
        enc_dense_hids: hidden-layer spec forwarded to the encoder network.
        dimension: data size D.
        dec_dense_hids: hidden-layer spec forwarded to the decoder network.
    """

    def __init__(self, hps, name="cvae"):
        self.hps = hps
        self.name = name

    def _net(self, x, cond, hids, out_dim, scope):
        """Run ``dense_nn``, or ``cond_dense_nn`` when ``cond`` is given."""
        # The positional ``False`` is forwarded unchanged from the original
        # call sites; its meaning is defined in .networks.
        if cond is None:
            return dense_nn(x, hids, out_dim, False, scope)
        return cond_dense_nn(x, cond, hids, out_dim, False, scope)

    @staticmethod
    def _split_loc_scale(params, size):
        """Split a ``[B, 2*size]`` output into ``(loc, scale)``, each ``[B, size]``."""
        loc = params[:, :size]
        # softplus keeps the diagonal std-dev positive.
        # NOTE(review): it can still underflow to 0 for very negative inputs.
        scale = tf.nn.softplus(params[:, size:])
        return loc, scale

    def enc(self, x, cond=None):
        """Encode ``x`` into a latent sample plus the negated KL term.

        Args:
            x: input tensor, [B, C].
            cond: optional conditioning tensor passed to ``cond_dense_nn``
                (presumably [B, C_cond] -- TODO confirm against .networks).

        Returns:
            A tuple ``(kl, z)``:
            kl: ``-KL(q(z | x, cond) || N(0, I))`` per example, shape [B].
                NOTE the sign: this is the *negative* KL divergence.
            z: reparameterized posterior sample ``m + eps * s`` with
                ``eps ~ N(0, I)``, shape [B, H].
        """
        hid = self.hps['hid_dimensions']
        batch_size = tf.shape(input=x)[0]
        with tf.compat.v1.variable_scope(self.name, reuse=tf.compat.v1.AUTO_REUSE):
            prior_dist = tfd.MultivariateNormalDiag(tf.zeros(hid), tf.ones(hid))
            params = self._net(x, cond, self.hps['enc_dense_hids'], 2 * hid, "enc")
            m, s = self._split_loc_scale(params, hid)
            posterior_dist = tfd.MultivariateNormalDiag(m, s)
            # Negated on purpose (existing contract); presumably callers add it
            # to the reconstruction log-likelihood to form the ELBO.
            kl = -tfd.kl_divergence(posterior_dist, prior_dist)
            # Reparameterization: sample standard noise, then shift/scale so
            # gradients flow through m and s.
            eps = prior_dist.sample(batch_size)
            posterior_sample = m + eps * s
        return kl, posterior_sample

    def dec(self, x, cond=None):
        """Decode a latent tensor into a diagonal Gaussian over data space.

        Args:
            x: latent tensor, presumably [B, H] as produced by ``enc``.
            cond: optional conditioning tensor passed to ``cond_dense_nn``.

        Returns:
            ``tfd.MultivariateNormalDiag`` whose loc and scale_diag are each
            [B, D], with D = ``hps['dimension']``.
        """
        dim = self.hps['dimension']
        with tf.compat.v1.variable_scope(self.name, reuse=tf.compat.v1.AUTO_REUSE):
            params = self._net(x, cond, self.hps['dec_dense_hids'], 2 * dim, "dec")
            m, s = self._split_loc_scale(params, dim)
            sample_dist = tfd.MultivariateNormalDiag(loc=m, scale_diag=s)
        return sample_dist