# NOTE: the extracted text began with page-status residue ("Spaces: Sleeping"); it is not part of the module.
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from easydict import EasyDict as edict

# Alias kept ahead of the local imports so the star import below can still
# shadow it exactly as in the original ordering.
tfd = tfp.distributions

from .base import BaseModel
from .pc_encoder import *
from .flow.transforms import Transform
from .cVAE import CondVAE
class ACSetVAE(BaseModel):
    """Set-level VAE with a flow-transformed prior and a per-element conditional VAE.

    The graph built by :meth:`build_net` consists of:

    * two ``LatentEncoder`` networks (``prior`` and ``posterior``) that map a
      masked set to a distribution over a set-level latent;
    * a ``Transform`` applied between the posterior and prior latent spaces;
    * an optional ``SetXformer`` producing a per-element embedding;
    * a ``CondVAE`` that encodes/decodes each set element given the set-level
      latent (and the optional per-element embedding).

    Hyper-parameters read from ``hps``: ``set_size``, ``dimension``,
    ``latent_dim``, ``trans_params``, ``vae_params`` and ``use_peq_embed``.
    """

    def __init__(self, hps):
        """Create the sub-networks, then hand over to ``BaseModel``.

        Args:
            hps: hyper-parameter object (attribute access) providing at least
                ``vae_params`` and ``use_peq_embed`` here, plus everything
                :meth:`build_net` reads.
        """
        # The sub-networks are created *before* BaseModel.__init__ runs.
        # NOTE(review): presumably BaseModel.__init__ stores `hps` and calls
        # build_net(), which needs these attributes -- confirm in .base.
        self.prior_net = LatentEncoder(hps, name='prior')
        self.posterior_net = LatentEncoder(hps, name='posterior')
        self.cvae = CondVAE(hps.vae_params, name="cvae")
        if hps.use_peq_embed:
            self.peq_embed = SetXformer(hps)

        super().__init__(hps)

    def _element_condition(self, set_latent, element_embed):
        """Broadcast a set-level latent to every element and append the embedding.

        Args:
            set_latent: tensor of shape ``[B, latent_dim]``.
            element_embed: tensor of shape ``[B * set_size, C]`` or ``None``.

        Returns:
            Tensor of shape ``[B * set_size, latent_dim]`` (or
            ``[B * set_size, latent_dim + C]`` when ``element_embed`` is given).
        """
        tiled = tf.tile(
            tf.expand_dims(set_latent, axis=1), [1, self.hps.set_size, 1]
        )
        cond = tf.reshape(tiled, [-1, self.hps.latent_dim])
        if element_embed is not None:
            cond = tf.concat([cond, element_embed], axis=-1)
        return cond

    def build_net(self):
        """Build placeholders, the training objective and the sampling path.

        Shapes below use ``B`` = batch, ``N`` = ``hps.set_size``,
        ``D`` = ``hps.dimension``, ``L`` = ``hps.latent_dim``.

        Attributes set:
            x, b, m: float32 placeholders of shape ``[B, N, D]``. ``b`` masks
                the prior input and ``m`` masks the posterior input.
            transform: the ``Transform`` between posterior and prior latents.
            set_elbo / set_metric: per-element objective, ``[B, N]``.
            elbo / metric: per-set objective, ``[B]``.
            loss: scalar, mean of ``-elbo``.
            sample: decoder samples drawn from the prior path, ``[B, N, D]``.
            log_likel: per-element log-probability of ``x`` under the prior
                path decoder plus ``vec_kl``, ``[B, N]``.
        """
        hps = self.hps
        set_size = hps.set_size
        dim = hps.dimension

        # NOTE(review): the extracted source had lost its indentation; the whole
        # body is placed inside the variable scope here -- confirm against
        # existing checkpoints' variable names.
        with tf.compat.v1.variable_scope('acset_vae', reuse=tf.compat.v1.AUTO_REUSE):
            set_shape = [None, set_size, dim]
            self.x = tf.compat.v1.placeholder(tf.float32, set_shape)
            self.b = tf.compat.v1.placeholder(tf.float32, set_shape)
            self.m = tf.compat.v1.placeholder(tf.float32, set_shape)

            # build transform
            self.transform = Transform(edict(hps.trans_params))

            # prior: conditioned on the b-masked values and the mask itself
            prior_inputs = tf.concat([self.x * self.b, self.b], axis=-1)  # [B, N, 2D]
            prior = self.prior_net(prior_inputs)
            prior_sample = prior.sample()
            prior_sample, _ = self.transform.inverse(prior_sample)  # [B, L]

            # optional per-element embedding, flattened to [B*N, C]
            cm = None
            if hps.use_peq_embed:
                peq_embed = self.peq_embed(prior_inputs)
                channels = peq_embed.get_shape().as_list()[-1]
                cm = tf.reshape(peq_embed, [-1, channels])

            # posterior: conditioned on the m-masked values and the mask itself
            posterior_inputs = tf.concat([self.x * self.m, self.m], axis=-1)  # [B, N, 2D]
            posterior = self.posterior_net(posterior_inputs)
            posterior_sample = posterior.sample()  # [B, L]

            # Set-level latent term: posterior entropy + log prior density of
            # the transformed posterior sample (incl. log-det). This is a
            # single-sample estimate of the *negative* KL, so it is added to
            # the objective rather than subtracted.
            z_sample, logdet = self.transform.forward(posterior_sample)
            logp = tf.reduce_sum(input_tensor=prior.log_prob(z_sample), axis=1) + logdet  # [B]
            neg_kl = tf.reduce_sum(input_tensor=posterior.entropy(), axis=1) + logp  # [B]

            # generator: every set element is handled independently
            x_flat = tf.reshape(self.x, [-1, dim])  # [B*N, D]
            cond_post = self._element_condition(posterior_sample, cm)

            # vector-wise posterior
            # NOTE(review): vec_kl is added to the objective below, so it is
            # presumably a negative-KL term as well -- confirm in CondVAE.enc.
            vec_kl, vec_post_sample = self.cvae.enc(
                tf.concat([x_flat, cond_post], axis=-1), cond_post
            )
            vec_kl = tf.reshape(vec_kl, [-1, set_size])  # [B, N]
            recon_dist = self.cvae.dec(
                tf.concat([vec_post_sample, cond_post], axis=-1), cond_post
            )
            elem_log_likel = tf.reshape(recon_dist.log_prob(x_flat), [-1, set_size])  # [B, N]

            # Per-element objective: the set-level term is shared evenly
            # across the N elements. vec_kl is not part of this quantity.
            self.set_metric = self.set_elbo = (
                elem_log_likel + tf.expand_dims(neg_kl, axis=1) / set_size
            )

            set_log_likel = tf.reduce_mean(input_tensor=elem_log_likel, axis=1)  # [B]
            tf.compat.v1.summary.scalar(
                'log_likel', tf.reduce_mean(input_tensor=set_log_likel)
            )

            # elbo
            self.elbo = (
                set_log_likel
                + neg_kl / set_size
                + tf.reduce_mean(input_tensor=vec_kl, axis=1)
            )
            self.metric = self.elbo
            self.loss = tf.reduce_mean(input_tensor=-self.elbo)
            tf.compat.v1.summary.scalar('loss', self.loss)

            # sample: same decoder, driven by the prior latent and a standard
            # normal element-level latent
            cond_prior = self._element_condition(prior_sample, cm)
            vec_prior_sample = tf.random.normal(shape=tf.shape(input=vec_post_sample))
            sample_dist = self.cvae.dec(
                tf.concat([vec_prior_sample, cond_prior], axis=-1), cond_prior
            )
            sample_log_likel = sample_dist.log_prob(x_flat)
            sample = sample_dist.sample()
            self.sample = tf.reshape(sample, [-1, set_size, dim])

            # compress
            self.log_likel = tf.reshape(sample_log_likel, [-1, set_size]) + vec_kl

            # To report the model size, sum the element counts of
            # tf.compat.v1.trainable_variables() here.