import numpy as np
import tensorflow as tf

tfk = tf.keras

# The modules below use graph-mode constructs (get_variable, variable scopes,
# placeholders), which require eager execution to be disabled under TF2.
tf.compat.v1.disable_eager_execution()
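# Invertible transforms for a small normalizing flow. Every module implements
# forward(), returning (z, log|det J|), and inverse(), returning
# (x, log|det J^{-1}|), so exact log-likelihood terms can be accumulated
# through a stack of layers.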
# base class
class BaseTransform(object):
    # subclasses implement forward/inverse, each returning (output, log-det)
    def __init__(self, hps, name='base'):
        self.name = name
        self.hps = hps
        self.build()

    def build(self):
        # optional: create variables / sub-networks
        pass

    def forward(self, x):
        # data x -> latent z, plus log|det| of the Jacobian
        raise NotImplementedError()

    def inverse(self, z):
        # latent z -> data x, plus log|det| of the inverse Jacobian
        raise NotImplementedError()
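# Transform chains the modules named in hps.transform (looked up in the TRANS
# registry at the bottom of the file), accumulating each module's log-det;
# inverse() applies the modules in reverse order.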
class Transform(BaseTransform):
    def __init__(self, hps, name='transform'):
        super(Transform, self).__init__(hps, name)

    def build(self):
        self.modules = []
        with tf.compat.v1.variable_scope(self.name, reuse=tf.compat.v1.AUTO_REUSE):
            for i, name in enumerate(self.hps.transform):
                m = TRANS[name](self.hps, f'{i}')
                self.modules.append(m)

    def forward(self, x):
        with tf.compat.v1.variable_scope(self.name, reuse=tf.compat.v1.AUTO_REUSE):
            logdet = 0.
            for module in self.modules:
                x, ldet = module.forward(x)
                logdet = logdet + ldet
            return x, logdet

    def inverse(self, z):
        with tf.compat.v1.variable_scope(self.name, reuse=tf.compat.v1.AUTO_REUSE):
            logdet = 0.
            for module in reversed(self.modules):
                z, ldet = module.inverse(z)
                logdet = logdet + ldet
            return z, logdet
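# Reverse flips the feature order. It is volume-preserving (log-det is 0) and
# is typically interleaved with coupling layers so that every dimension
# eventually gets transformed.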
class Reverse(BaseTransform):
    def __init__(self, hps, name):
        name = f'reverse_{name}'
        super(Reverse, self).__init__(hps, name)

    def forward(self, x):
        z = tf.reverse(x, [-1])
        ldet = 0.0
        return z, ldet

    def inverse(self, z):
        x = tf.reverse(z, [-1])
        ldet = 0.0
        return x, ldet
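# Invertible leaky ReLU with a learnable slope. alpha = sigmoid(log_alpha)
# lies in (0, 1), so the map is bijective; the Jacobian is diagonal with
# entries 1 (for x >= 0) or alpha (for x < 0), hence
# log|det J| = (#negative entries) * log(alpha).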
class LeakyReLU(BaseTransform):
    def __init__(self, hps, name):
        name = f'lrelu_{name}'
        super(LeakyReLU, self).__init__(hps, name)

    def build(self):
        with tf.compat.v1.variable_scope(self.name, reuse=tf.compat.v1.AUTO_REUSE):
            self.alpha = tf.nn.sigmoid(
                tf.compat.v1.get_variable('log_alpha',
                                          initializer=5.0,
                                          dtype=tf.float32))

    def forward(self, x):
        num_negative = tf.reduce_sum(input_tensor=tf.cast(tf.less(x, 0.0), tf.float32), axis=1)
        ldet = num_negative * tf.math.log(self.alpha)
        z = tf.maximum(x, self.alpha * x)
        return z, ldet

    def inverse(self, z):
        num_negative = tf.reduce_sum(input_tensor=tf.cast(tf.less(z, 0.0), tf.float32), axis=1)
        ldet = -1. * num_negative * tf.math.log(self.alpha)
        x = tf.minimum(z, z / self.alpha)
        return x, ldet
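# Affine coupling layer in the style of RealNVP: the even-indexed features
# condition a scale/shift applied to the odd-indexed features, then the roles
# are swapped so both halves get transformed. The Jacobian is triangular, so
# log|det J| is simply the sum of the scale outputs. This assumes
# hps.dimension is even, so each half matches the split of the net output.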
class Coupling(BaseTransform):
    def __init__(self, hps, name):
        name = f'coupling_{name}'
        super(Coupling, self).__init__(hps, name)

    def build(self):
        d = self.hps.dimension
        with tf.compat.v1.variable_scope(self.name, reuse=tf.compat.v1.AUTO_REUSE):
            # each net maps d/2 inputs to d outputs, later split into
            # (scale, shift); the zero-initialized final layer makes the
            # coupling start out as the identity map
            self.net1 = tfk.Sequential(name=f'{self.name}/ms1')
            for i, h in enumerate(self.hps.coupling_hids):
                self.net1.add(tfk.layers.Dense(h, activation=tf.nn.tanh, name=f'l{i}'))
            self.net1.add(tfk.layers.Dense(d, name=f'l{i+1}', kernel_initializer=tf.compat.v1.zeros_initializer()))
            self.net2 = tfk.Sequential(name=f'{self.name}/ms2')
            for i, h in enumerate(self.hps.coupling_hids):
                self.net2.add(tfk.layers.Dense(h, activation=tf.nn.tanh, name=f'l{i}'))
            self.net2.add(tfk.layers.Dense(d, name=f'l{i+1}', kernel_initializer=tf.compat.v1.zeros_initializer()))

    def forward(self, x):
        B = tf.shape(input=x)[0]
        d = self.hps.dimension
        ldet = tf.zeros(B, dtype=tf.float32)
        # part 1: even-indexed features condition the odd-indexed ones
        inp, out = x[:, ::2], x[:, 1::2]
        scale, shift = tf.split(self.net1(inp), 2, axis=1)
        out = (out + shift) * tf.exp(scale)
        # re-interleave the two halves back into their original positions
        x = tf.reshape(tf.stack([inp, out], axis=-1), [B, d])
        ldet = ldet + tf.reduce_sum(input_tensor=scale, axis=1)
        # part 2: roles swapped, odd-indexed features condition the even ones
        out, inp = x[:, ::2], x[:, 1::2]
        scale, shift = tf.split(self.net2(inp), 2, axis=1)
        out = (out + shift) * tf.exp(scale)
        x = tf.reshape(tf.stack([out, inp], axis=-1), [B, d])
        ldet = ldet + tf.reduce_sum(input_tensor=scale, axis=1)
        return x, ldet

    def inverse(self, z):
        B = tf.shape(input=z)[0]
        d = self.hps.dimension
        ldet = tf.zeros(B, dtype=tf.float32)
        # part 2 is undone first (reverse order of forward)
        out, inp = z[:, ::2], z[:, 1::2]
        scale, shift = tf.split(self.net2(inp), 2, axis=1)
        out = out * tf.exp(-scale) - shift
        z = tf.reshape(tf.stack([out, inp], axis=-1), [B, d])
        ldet = ldet - tf.reduce_sum(input_tensor=scale, axis=1)
        # then part 1
        inp, out = z[:, ::2], z[:, 1::2]
        scale, shift = tf.split(self.net1(inp), 2, axis=1)
        out = out * tf.exp(-scale) - shift
        z = tf.reshape(tf.stack([inp, out], axis=-1), [B, d])
        ldet = ldet - tf.reduce_sum(input_tensor=scale, axis=1)
        return z, ldet
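# Invertible linear layer with an LU-parameterized weight, similar to the LU
# variant of the invertible 1x1 convolution in Glow: a single variable W
# stores both the strict lower triangle of L (completed with a unit diagonal)
# and the upper triangle U, so A = L @ U, log|det A| = sum(log|diag(U)|) is
# cheap to compute, and the inverse needs only triangular solves.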
class LULinear(BaseTransform):
    def __init__(self, hps, name):
        name = f'linear_{name}'
        super(LULinear, self).__init__(hps, name)

    def build(self):
        d = self.hps.dimension
        with tf.compat.v1.variable_scope(self.name, reuse=tf.compat.v1.AUTO_REUSE):
            # initialize W to the identity so the layer starts as a no-op
            np_w = np.eye(d).astype("float32")
            self.w = tf.compat.v1.get_variable('W', initializer=np_w)
            self.b = tf.compat.v1.get_variable('b', initializer=tf.zeros([d]))

    def get_LU(self):
        d = self.hps.dimension
        W = self.w
        U = tf.linalg.band_part(W, 0, -1)  # upper triangle incl. diagonal
        L = tf.eye(d) + W - U              # strict lower triangle, unit diagonal
        A = tf.matmul(L, U)
        return A, L, U

    def forward(self, x):
        A, L, U = self.get_LU()
        ldet = tf.reduce_sum(input_tensor=tf.math.log(tf.abs(tf.linalg.diag_part(U))))
        z = tf.matmul(x, A) + self.b
        return z, ldet

    def inverse(self, z):
        B = tf.shape(input=z)[0]
        A, L, U = self.get_LU()
        ldet = -1 * tf.reduce_sum(input_tensor=tf.math.log(tf.abs(tf.linalg.diag_part(U))))
        # solve x @ A = z - b via two batched triangular solves on A^T = U^T L^T
        Ut = tf.tile(tf.expand_dims(tf.transpose(a=U, perm=[1, 0]), axis=0), [B, 1, 1])
        Lt = tf.tile(tf.expand_dims(tf.transpose(a=L, perm=[1, 0]), axis=0), [B, 1, 1])
        zt = tf.expand_dims(z - self.b, -1)
        sol = tf.linalg.triangular_solve(Ut, zt)              # U^T is lower triangular
        x = tf.linalg.triangular_solve(Lt, sol, lower=False)  # L^T is upper triangular
        x = tf.squeeze(x, axis=-1)
        return x, ldet
# register all modules
TRANS = {
    'CP': Coupling,
    'R': Reverse,
    'LR': LeakyReLU,
    'L': LULinear,
}
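# Smoke test: stack two Transform blocks, verify that inverse(forward(x))
# reconstructs x (err should stay near 0) and that the forward and backward
# log-dets cancel (det should stay near 0), while fitting the flow to
# standard-Gaussian data.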
if __name__ == '__main__':
    from pprint import pformat
    from easydict import EasyDict as edict

    hps = edict()
    hps.dimension = 8
    hps.coupling_hids = [32, 32]
    hps.transform = ['L', 'LR', 'CP', 'R']

    x_ph = tf.compat.v1.placeholder(tf.float32, [32, 8])
    l1 = Transform(hps, '1')
    l2 = Transform(hps, '2')
    z, fdet1 = l1.forward(x_ph)
    z, fdet2 = l2.forward(z)
    fdet = fdet1 + fdet2
    x, bdet2 = l2.inverse(z)
    x, bdet1 = l1.inverse(x)
    bdet = bdet1 + bdet2

    # invertibility checks: reconstruction error and log-det cancellation
    err = tf.reduce_sum(input_tensor=tf.square(x_ph - x))
    det = tf.reduce_sum(input_tensor=fdet + bdet)
    # flow objective: per-sample squared norm of z minus the forward log-det,
    # reduced over the batch so the two terms are weighted consistently
    nll = tf.reduce_sum(input_tensor=tf.square(z), axis=1) - fdet
    loss = tf.reduce_sum(input_tensor=nll)
    train_op = tf.compat.v1.train.AdamOptimizer(0.0001).minimize(loss)

    sess = tf.compat.v1.Session()
    sess.run(tf.compat.v1.global_variables_initializer())
    print('=' * 20)
    print('Variables:')
    print(pformat(tf.compat.v1.trainable_variables()))
    for i in range(1000):
        x_nda = np.random.randn(32, 8)
        feed_dict = {x_ph: x_nda}
        res = sess.run([err, det], feed_dict)
        print(f'err:{res[0]} det:{res[1]}')
        sess.run(train_op, feed_dict)