import numpy as np
import tensorflow as tf

tfk = tf.keras


# base class
class BaseTransform(object):
    """Invertible transform: `forward` and `inverse` return (output, logdet)."""

    def __init__(self, hps, name='base'):
        self.name = name
        self.hps = hps
        self.build()

    def build(self):
        pass

    def forward(self, x):
        raise NotImplementedError()

    def inverse(self, z):
        raise NotImplementedError()


class Transform(BaseTransform):
    """A flow: the composition of the modules listed in `hps.transform`."""

    def __init__(self, hps, name='transform'):
        super(Transform, self).__init__(hps, name)

    def build(self):
        self.modules = []
        with tf.compat.v1.variable_scope(self.name, reuse=tf.compat.v1.AUTO_REUSE):
            for i, name in enumerate(self.hps.transform):
                m = TRANS[name](self.hps, f'{i}')
                self.modules.append(m)

    def forward(self, x):
        with tf.compat.v1.variable_scope(self.name, reuse=tf.compat.v1.AUTO_REUSE):
            logdet = 0.
            for module in self.modules:
                x, ldet = module.forward(x)
                logdet = logdet + ldet
            return x, logdet

    def inverse(self, z):
        # run the modules in reverse order; the log-dets of the inverses accumulate
        with tf.compat.v1.variable_scope(self.name, reuse=tf.compat.v1.AUTO_REUSE):
            logdet = 0.
            for module in reversed(self.modules):
                z, ldet = module.inverse(z)
                logdet = logdet + ldet
            return z, logdet


class Reverse(BaseTransform):
    """Flips the feature order; a volume-preserving permutation (logdet = 0)."""

    def __init__(self, hps, name):
        name = f'reverse_{name}'
        super(Reverse, self).__init__(hps, name)

    def forward(self, x):
        z = tf.reverse(x, [-1])
        ldet = 0.0
        return z, ldet

    def inverse(self, z):
        x = tf.reverse(z, [-1])
        ldet = 0.0
        return x, ldet


class LeakyReLU(BaseTransform):
    """Invertible leaky ReLU with a learned slope alpha = sigmoid(log_alpha)."""

    def __init__(self, hps, name):
        name = f'lrelu_{name}'
        super(LeakyReLU, self).__init__(hps, name)

    def build(self):
        with tf.compat.v1.variable_scope(self.name, reuse=tf.compat.v1.AUTO_REUSE):
            self.alpha = tf.nn.sigmoid(
                tf.compat.v1.get_variable('log_alpha', initializer=5.0, dtype=tf.float32))

    def forward(self, x):
        # the Jacobian is diagonal: alpha on negative entries, 1 elsewhere
        num_negative = tf.reduce_sum(
            input_tensor=tf.cast(tf.less(x, 0.0), tf.float32), axis=1)
        ldet = num_negative * tf.math.log(self.alpha)
        z = tf.maximum(x, self.alpha * x)
        return z, ldet

    def inverse(self, z):
        num_negative = tf.reduce_sum(
            input_tensor=tf.cast(tf.less(z, 0.0), tf.float32), axis=1)
        ldet = -1. * num_negative * tf.math.log(self.alpha)
        x = tf.minimum(z, z / self.alpha)
        return x, ldet
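
# A minimal NumPy sketch (illustrative, not part of the original module) of why
# LeakyReLU's log-det is `num_negative * log(alpha)`: the map z_i = max(x_i,
# alpha * x_i) has a diagonal Jacobian with entries 1 where x_i >= 0 and alpha
# where x_i < 0, so log|det J| just counts the negative coordinates.
def _np_lrelu_logdet(x, alpha):
    """Reference per-sample log|det J| for leaky ReLU, for eyeballing against
    LeakyReLU.forward; `x` is a [B, d] array, `alpha` a scalar in (0, 1)."""
    return (x < 0.0).sum(axis=1) * np.log(alpha)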
class Coupling(BaseTransform):
    """Affine coupling (Real NVP style): each half of the features is scaled
    and shifted conditioned on the other half, so both halves get transformed."""

    def __init__(self, hps, name):
        name = f'coupling_{name}'
        super(Coupling, self).__init__(hps, name)

    def build(self):
        d = self.hps.dimension
        with tf.compat.v1.variable_scope(self.name, reuse=tf.compat.v1.AUTO_REUSE):
            # each net maps d/2 inputs to d outputs, split into scale and shift;
            # the zero-initialized last layer makes the coupling start as identity
            self.net1 = tfk.Sequential(name=f'{self.name}/ms1')
            for i, h in enumerate(self.hps.coupling_hids):
                self.net1.add(tfk.layers.Dense(h, activation=tf.nn.tanh, name=f'l{i}'))
            self.net1.add(tfk.layers.Dense(
                d, name=f'l{i+1}', kernel_initializer=tf.compat.v1.zeros_initializer()))
            self.net2 = tfk.Sequential(name=f'{self.name}/ms2')
            for i, h in enumerate(self.hps.coupling_hids):
                self.net2.add(tfk.layers.Dense(h, activation=tf.nn.tanh, name=f'l{i}'))
            self.net2.add(tfk.layers.Dense(
                d, name=f'l{i+1}', kernel_initializer=tf.compat.v1.zeros_initializer()))

    def forward(self, x):
        B = tf.shape(input=x)[0]
        d = self.hps.dimension
        ldet = tf.zeros(B, dtype=tf.float32)
        # part 1: transform the odd half conditioned on the even half
        inp, out = x[:, ::2], x[:, 1::2]
        scale, shift = tf.split(self.net1(inp), 2, axis=1)
        out = (out + shift) * tf.exp(scale)
        x = tf.reshape(tf.stack([inp, out], axis=-1), [B, d])
        ldet = ldet + tf.reduce_sum(input_tensor=scale, axis=1)
        # part 2: transform the even half conditioned on the odd half
        out, inp = x[:, ::2], x[:, 1::2]
        scale, shift = tf.split(self.net2(inp), 2, axis=1)
        out = (out + shift) * tf.exp(scale)
        x = tf.reshape(tf.stack([out, inp], axis=-1), [B, d])
        ldet = ldet + tf.reduce_sum(input_tensor=scale, axis=1)
        return x, ldet

    def inverse(self, z):
        B = tf.shape(input=z)[0]
        d = self.hps.dimension
        ldet = tf.zeros(B, dtype=tf.float32)
        # part 2 first: undo the forward transforms in reverse order
        out, inp = z[:, ::2], z[:, 1::2]
        scale, shift = tf.split(self.net2(inp), 2, axis=1)
        out = out * tf.exp(-scale) - shift
        z = tf.reshape(tf.stack([out, inp], axis=-1), [B, d])
        ldet = ldet - tf.reduce_sum(input_tensor=scale, axis=1)
        # part 1
        inp, out = z[:, ::2], z[:, 1::2]
        scale, shift = tf.split(self.net1(inp), 2, axis=1)
        out = out * tf.exp(-scale) - shift
        z = tf.reshape(tf.stack([inp, out], axis=-1), [B, d])
        ldet = ldet - tf.reduce_sum(input_tensor=scale, axis=1)
        return z, ldet


class LULinear(BaseTransform):
    """Invertible linear layer parameterized via an LU decomposition, so the
    log-determinant is just the sum of log|diag(U)|."""

    def __init__(self, hps, name):
        name = f'linear_{name}'
        super(LULinear, self).__init__(hps, name)

    def build(self):
        d = self.hps.dimension
        with tf.compat.v1.variable_scope(self.name, reuse=tf.compat.v1.AUTO_REUSE):
            np_w = np.eye(d).astype("float32")
            self.w = tf.compat.v1.get_variable('W', initializer=np_w)
            self.b = tf.compat.v1.get_variable('b', initializer=tf.zeros([d]))

    def get_LU(self):
        # U: upper triangle of W (with diagonal); L: unit lower triangle
        d = self.hps.dimension
        W = self.w
        U = tf.linalg.band_part(W, 0, -1)
        L = tf.eye(d) + W - U
        A = tf.matmul(L, U)
        return A, L, U

    def forward(self, x):
        A, L, U = self.get_LU()
        ldet = tf.reduce_sum(input_tensor=tf.math.log(tf.abs(tf.linalg.diag_part(U))))
        z = tf.matmul(x, A) + self.b
        return z, ldet

    def inverse(self, z):
        # solve x A = z - b via two triangular solves on A^T = U^T L^T
        B = tf.shape(input=z)[0]
        A, L, U = self.get_LU()
        ldet = -1 * tf.reduce_sum(input_tensor=tf.math.log(tf.abs(tf.linalg.diag_part(U))))
        Ut = tf.tile(tf.expand_dims(tf.transpose(a=U, perm=[1, 0]), axis=0), [B, 1, 1])
        Lt = tf.tile(tf.expand_dims(tf.transpose(a=L, perm=[1, 0]), axis=0), [B, 1, 1])
        zt = tf.expand_dims(z - self.b, -1)
        sol = tf.linalg.triangular_solve(Ut, zt)              # U^T y = (z - b)^T
        x = tf.linalg.triangular_solve(Lt, sol, lower=False)  # L^T x^T = y
        x = tf.squeeze(x, axis=-1)
        return x, ldet


# register all modules
TRANS = {
    'CP': Coupling,
    'R': Reverse,
    'LR': LeakyReLU,
    'L': LULinear,
}
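
# A small sketch (not in the original file) checking that the stack/reshape in
# Coupling exactly inverts the even/odd slicing it starts from, which is what
# keeps its forward and inverse passes consistent.
def _np_interleave_roundtrip(batch=4, d=8):
    x = np.random.randn(batch, d).astype('float32')
    inp, out = x[:, ::2], x[:, 1::2]                         # even/odd halves
    x_rt = np.stack([inp, out], axis=-1).reshape(batch, d)   # re-interleave
    return np.allclose(x, x_rt)                              # True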
if __name__ == '__main__':
    from pprint import pformat
    from easydict import EasyDict as edict

    # TF1-style placeholders and sessions require graph mode under TF2
    tf.compat.v1.disable_eager_execution()

    hps = edict()
    hps.dimension = 8
    hps.coupling_hids = [32, 32]
    hps.transform = ['L', 'LR', 'CP', 'R']

    x_ph = tf.compat.v1.placeholder(tf.float32, [32, 8])
    l1 = Transform(hps, '1')
    l2 = Transform(hps, '2')
    z, fdet1 = l1.forward(x_ph)
    z, fdet2 = l2.forward(z)
    fdet = fdet1 + fdet2
    x, bdet2 = l2.inverse(z)
    x, bdet1 = l1.inverse(x)
    bdet = bdet1 + bdet2

    # invertibility checks: both should stay near zero
    err = tf.reduce_sum(input_tensor=tf.square(x_ph - x))
    det = tf.reduce_sum(input_tensor=fdet + bdet)
    # negative log-likelihood under a standard Gaussian base density, up to
    # constants (summed over the batch so the loss is a scalar)
    loss = tf.reduce_sum(input_tensor=tf.square(z)) - tf.reduce_sum(input_tensor=fdet)
    train_op = tf.compat.v1.train.AdamOptimizer(0.0001).minimize(loss)

    sess = tf.compat.v1.Session()
    sess.run(tf.compat.v1.global_variables_initializer())
    print('=' * 20)
    print('Variables:')
    print(pformat(tf.compat.v1.trainable_variables()))
    for i in range(1000):
        x_nda = np.random.randn(32, 8)
        feed_dict = {x_ph: x_nda}
        res = sess.run([err, det], feed_dict)
        print(f'err:{res[0]} det:{res[1]}')
        sess.run(train_op, feed_dict)
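
    # Sampling sketch (an illustrative addition, not in the original script):
    # a trained flow generates data by pushing Gaussian noise through the
    # inverse transforms in reverse order; no new variables are created, so
    # the existing session can run the new ops directly.
    z_ph = tf.compat.v1.placeholder(tf.float32, [32, 8])
    h, _ = l2.inverse(z_ph)
    x_sample, _ = l1.inverse(h)
    samples = sess.run(x_sample, {z_ph: np.random.randn(32, 8)})
    print('sample mean:', samples.mean())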