Spaces:
Running
Running
import numpy as np | |
import tensorflow as tf | |
import tensorflow_probability as tfp | |
tfd = tfp.distributions | |
from .set_transformer import set_transformer | |
class LatentEncoder(object):
    """Encode a set of feature vectors into a Normal distribution over a latent vector.

    Graph built by ``__call__`` (all variables live under the ``self.name``
    variable scope):
    ``set_transformer`` -> mean over axis 1 -> dense(2 * latent_dim) ->
    leaky ReLU -> dense(2 * latent_dim) -> split into (loc, scale).

    Args:
        hps: hyperparameter object. This class reads
            ``hps.latent_encoder_hidden`` (forwarded to ``set_transformer``)
            and ``hps.latent_dim`` (size of the latent vector).
        name: variable-scope name; defaults to ``'latent'``.
    """

    def __init__(self, hps, name='latent'):
        self.hps = hps
        self.name = name

    def __call__(self, x):
        '''Build the encoder graph for ``x`` and return the latent distribution.

        Args:
            x: input set, documented as shape [B, N, C]. Axis 1 is the axis that
                gets mean-pooled below, so it is assumed to be the set axis
                (TODO confirm ``set_transformer`` preserves that layout).

        Returns:
            ``tfd.Normal`` whose ``loc`` is the first ``latent_dim`` outputs of the
            last dense layer and whose ``scale`` is the softplus of the remaining
            ``latent_dim`` outputs. It is elementwise (not wrapped in
            ``tfd.Independent``).
        '''
        latent_dim = self.hps.latent_dim
        # The original unpacked B, N, C from x's shape here but never used them.
        # That added dead shape ops and raised for inputs whose static rank was
        # not exactly 3, so it has been removed.
        # AUTO_REUSE: calling this again under the same name reuses the variables
        # created by the first call instead of raising.
        with tf.compat.v1.variable_scope(self.name, reuse=tf.compat.v1.AUTO_REUSE):
            h = set_transformer(x, self.hps.latent_encoder_hidden, name='set_xformer')
            # Pool over axis 1 (assumed set axis) to get one vector per example.
            h = tf.reduce_mean(input_tensor=h, axis=1)
            h = tf.compat.v1.layers.dense(h, latent_dim * 2, name='d1')
            h = tf.nn.leaky_relu(h)
            h = tf.compat.v1.layers.dense(h, latent_dim * 2, name='d2')
            loc = h[..., :latent_dim]
            # NOTE(review): softplus can underflow to exactly 0 for very negative
            # pre-activations, producing a degenerate Normal. Consider a small
            # floor (e.g. + 1e-5) if NaN log-probs show up in training.
            scale = tf.nn.softplus(h[..., latent_dim:])
            dist = tfd.Normal(loc=loc, scale=scale)
        return dist
class SetXformer(object):
    """Thin wrapper that applies ``set_transformer`` inside a reusable variable scope.

    Args:
        hps: hyperparameter object; ``hps.set_xformer_hids`` is forwarded to
            ``set_transformer`` as its hidden-size configuration.
        name: variable-scope name; defaults to ``'set_xformer'``.
    """

    def __init__(self, hps, name='set_xformer'):
        self.name = name
        self.hps = hps

    def __call__(self, x):
        """Run ``set_transformer`` on ``x`` and return its output unchanged.

        With ``reuse=AUTO_REUSE``, later calls under the same scope name share
        the variables created by the first call.
        """
        hidden_cfg = self.hps.set_xformer_hids
        with tf.compat.v1.variable_scope(self.name, reuse=tf.compat.v1.AUTO_REUSE):
            encoded = set_transformer(x, hidden_cfg, name='set_xformer')
        return encoded