# NCTCMumbai's picture
# Upload 2583 files
# 97b6013 verified
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import sonnet as snt
import tensorflow as tf
import numpy as np
import math
# Names of the supported observation (emission) models p(x | z).
SQUARED_OBSERVATION = "squared"
ABS_OBSERVATION = "abs"
STANDARD_OBSERVATION = "standard"
OBSERVATION_TYPES = [SQUARED_OBSERVATION, ABS_OBSERVATION, STANDARD_OBSERVATION]
# Names of the supported transition models p(z_t | z_{t-1}).
ROUND_TRANSITION = "round"
STANDARD_TRANSITION = "standard"
TRANSITION_TYPES = [ROUND_TRANSITION, STANDARD_TRANSITION]
class Q(object):
  """Amortized variational posterior q(z_t | observation, z_{t-1}).

  Holds one linear mean layer and one learned diagonal variance per
  timestep. The mean at step t is a linear function of the observation
  concatenated with the previous latent state.
  """

  def __init__(self,
               state_size,
               num_timesteps,
               sigma_min=1e-5,
               dtype=tf.float32,
               random_seed=None,
               init_mu0_to_zero=False,
               graph_collection_name="Q_VARS"):
    """Creates the per-timestep mean layers and variance variables.

    Args:
      state_size: Dimension of the latent state z_t.
      num_timesteps: Number of timesteps in the chain.
      sigma_min: Floor applied to the softplus-transformed variance.
      dtype: dtype of the created variables.
      random_seed: Seed for the uniform initializers.
      init_mu0_to_zero: If True, the t=0 mean layer is zero-initialized.
      graph_collection_name: Graph collection all variables of this module
        are added to (lets callers build optimizers over q alone).
    """
    self.sigma_min = sigma_min
    self.dtype = dtype
    self.graph_collection_name = graph_collection_name
    initializers = []
    # `range` (not `xrange`) keeps this Python 3 compatible.
    for t in range(num_timesteps):
      if t == 0 and init_mu0_to_zero:
        initializers.append(
            {"w": tf.zeros_initializer, "b": tf.zeros_initializer})
      else:
        initializers.append(
            {"w": tf.random_uniform_initializer(seed=random_seed),
             "b": tf.zeros_initializer})

    def custom_getter(getter, *args, **kwargs):
      # Route every variable created by the snt.Linear modules into the
      # requested graph collection, exactly once.
      out = getter(*args, **kwargs)
      ref = tf.get_collection_ref(self.graph_collection_name)
      if out not in ref:
        ref.append(out)
      return out

    self.mus = [
        snt.Linear(output_size=state_size,
                   initializers=initializers[t],
                   name="q_mu_%d" % t,
                   custom_getter=custom_getter)
        for t in range(num_timesteps)
    ]
    self.sigmas = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="q_sigma_%d" % (t + 1),
            collections=[tf.GraphKeys.GLOBAL_VARIABLES, graph_collection_name],
            initializer=tf.random_uniform_initializer(seed=random_seed))
        for t in range(num_timesteps)
    ]

  def q_zt(self, observation, prev_state, t):
    """Returns a Normal distribution q(z_t | observation, prev_state)."""
    batch_size = tf.shape(prev_state)[0]
    q_mu = self.mus[t](tf.concat([observation, prev_state], axis=1))
    # Softplus keeps the variance positive; sigma_min prevents collapse.
    q_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min)
    q_sigma = tf.tile(q_sigma[tf.newaxis, :], [batch_size, 1])
    q_zt = tf.contrib.distributions.Normal(loc=q_mu, scale=tf.sqrt(q_sigma))
    return q_zt

  def summarize_weights(self):
    """Writes scalar summaries of the first entry of each q parameter."""
    for t, sigma in enumerate(self.sigmas):
      tf.summary.scalar("q_sigma/%d" % t, sigma[0])
    for t, f in enumerate(self.mus):
      tf.summary.scalar("q_mu/b_%d" % t, f.b[0])
      # NOTE(review): w[0,0] / w[1,0] only cleanly separate observation and
      # previous-state weights when the observation is 1-D — confirm.
      tf.summary.scalar("q_mu/w_obs_%d" % t, f.w[0, 0])
      if t != 0:
        tf.summary.scalar("q_mu/w_prev_state_%d" % t, f.w[1, 0])
class PreviousStateQ(Q):
  """A Q variant whose mean depends only on the previous latent state."""

  def q_zt(self, unused_observation, prev_state, t):
    """Returns q(z_t | z_{t-1}); the observation argument is ignored."""
    n = tf.shape(prev_state)[0]
    mean = self.mus[t](prev_state)
    variance = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min)
    variance = tf.tile(variance[tf.newaxis, :], [n, 1])
    return tf.contrib.distributions.Normal(loc=mean, scale=tf.sqrt(variance))

  def summarize_weights(self):
    """Writes scalar summaries for the per-timestep q parameters."""
    for t, sigma in enumerate(self.sigmas):
      tf.summary.scalar("q_sigma/%d" % t, sigma[0])
    for t, layer in enumerate(self.mus):
      tf.summary.scalar("q_mu/b_%d" % t, layer.b[0])
      tf.summary.scalar("q_mu/w_prev_state_%d" % t, layer.w[0, 0])
class ObservationQ(Q):
  """A Q variant whose mean depends only on the observation."""

  def q_zt(self, observation, prev_state, t):
    """Returns q(z_t | observation); prev_state only supplies batch size."""
    n = tf.shape(prev_state)[0]
    mean = self.mus[t](observation)
    variance = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min)
    variance = tf.tile(variance[tf.newaxis, :], [n, 1])
    return tf.contrib.distributions.Normal(loc=mean, scale=tf.sqrt(variance))

  def summarize_weights(self):
    """Writes scalar summaries for the per-timestep q parameters."""
    for t, sigma in enumerate(self.sigmas):
      tf.summary.scalar("q_sigma/%d" % t, sigma[0])
    for t, layer in enumerate(self.mus):
      tf.summary.scalar("q_mu/b_%d" % t, layer.b[0])
      tf.summary.scalar("q_mu/w_obs_%d" % t, layer.w[0, 0])
class SimpleMeanQ(object):
  """Variational posterior with a free per-timestep mean and variance.

  Unlike Q, the mean at each timestep is a learned constant rather than a
  function of the observation or the previous state.
  """

  def __init__(self,
               state_size,
               num_timesteps,
               sigma_min=1e-5,
               dtype=tf.float32,
               random_seed=None,
               init_mu0_to_zero=False,
               graph_collection_name="Q_VARS"):
    """Creates the per-timestep mean and variance variables.

    Args:
      state_size: Dimension of the latent state z_t.
      num_timesteps: Number of timesteps in the chain.
      sigma_min: Floor applied to the softplus-transformed variance.
      dtype: dtype of the created variables.
      random_seed: Seed for the uniform initializers.
      init_mu0_to_zero: If True, the t=0 mean is zero-initialized.
      graph_collection_name: Graph collection the variables are added to.
    """
    self.sigma_min = sigma_min
    self.dtype = dtype
    self.graph_collection_name = graph_collection_name
    initializers = []
    # `range` (not `xrange`) keeps this Python 3 compatible.
    for t in range(num_timesteps):
      if t == 0 and init_mu0_to_zero:
        initializers.append(tf.zeros_initializer)
      else:
        initializers.append(tf.random_uniform_initializer(seed=random_seed))
    self.mus = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="q_mu_%d" % (t + 1),
            collections=[tf.GraphKeys.GLOBAL_VARIABLES, graph_collection_name],
            initializer=initializers[t])
        for t in range(num_timesteps)
    ]
    self.sigmas = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="q_sigma_%d" % (t + 1),
            collections=[tf.GraphKeys.GLOBAL_VARIABLES, graph_collection_name],
            initializer=tf.random_uniform_initializer(seed=random_seed))
        for t in range(num_timesteps)
    ]

  def q_zt(self, unused_observation, prev_state, t):
    """Returns q(z_t); prev_state is used only to infer the batch size."""
    batch_size = tf.shape(prev_state)[0]
    q_mu = tf.tile(self.mus[t][tf.newaxis, :], [batch_size, 1])
    q_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min)
    q_sigma = tf.tile(q_sigma[tf.newaxis, :], [batch_size, 1])
    q_zt = tf.contrib.distributions.Normal(loc=q_mu, scale=tf.sqrt(q_sigma))
    return q_zt

  def summarize_weights(self):
    """Writes scalar summaries of the first entry of each parameter."""
    for t, sigma in enumerate(self.sigmas):
      tf.summary.scalar("q_sigma/%d" % t, sigma[0])
    for t, f in enumerate(self.mus):
      tf.summary.scalar("q_mu/%d" % t, f[0])
class R(object):
  """Lookahead model r(x_n | z_t): a per-timestep linear-Gaussian head.

  Each timestep has a linear mean layer over z_t and a learned diagonal
  variance.
  """

  def __init__(self,
               state_size,
               num_timesteps,
               sigma_min=1e-5,
               dtype=tf.float32,
               sigma_init=1.,
               random_seed=None,
               graph_collection_name="R_VARS"):
    """Creates the per-timestep mean layers and variance variables.

    Args:
      state_size: Dimension of the latent state / observation.
      num_timesteps: Number of timesteps in the chain.
      sigma_min: Floor applied to the softplus-transformed variance.
      dtype: dtype of the created variables.
      sigma_init: Constant initial value for the raw variances.
      random_seed: Seed for the weight initializer.
      graph_collection_name: Graph collection the variables are added to.
    """
    self.dtype = dtype
    self.sigma_min = sigma_min
    initializers = {"w": tf.truncated_normal_initializer(seed=random_seed),
                    "b": tf.zeros_initializer}
    self.graph_collection_name = graph_collection_name

    def custom_getter(getter, *args, **kwargs):
      # Route all snt.Linear variables into the requested collection once.
      out = getter(*args, **kwargs)
      ref = tf.get_collection_ref(self.graph_collection_name)
      if out not in ref:
        ref.append(out)
      return out

    # `range` (not `xrange`) keeps this Python 3 compatible.
    self.mus = [
        snt.Linear(output_size=state_size,
                   initializers=initializers,
                   name="r_mu_%d" % t,
                   custom_getter=custom_getter)
        for t in range(num_timesteps)
    ]
    self.sigmas = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="r_sigma_%d" % (t + 1),
            collections=[tf.GraphKeys.GLOBAL_VARIABLES, graph_collection_name],
            initializer=tf.constant_initializer(sigma_init))
        for t in range(num_timesteps)
    ]

  def r_xn(self, z_t, t):
    """Returns the Normal distribution r(x_n | z_t) at timestep t."""
    batch_size = tf.shape(z_t)[0]
    r_mu = self.mus[t](z_t)
    r_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min)
    r_sigma = tf.tile(r_sigma[tf.newaxis, :], [batch_size, 1])
    return tf.contrib.distributions.Normal(
        loc=r_mu, scale=tf.sqrt(r_sigma))

  def summarize_weights(self):
    """Writes scalar summaries of the r parameters.

    Note: mus are snt.Linear modules; the original indexed the module
    itself (`self.mus[t][0]`), which raises TypeError. Summarize the
    bias instead, mirroring Q.summarize_weights.
    """
    for t in range(len(self.mus) - 1):
      tf.summary.scalar("r_mu/%d" % t, self.mus[t].b[0])
      tf.summary.scalar("r_sigma/%d" % t, self.sigmas[t][0])
class P(object):
  """Linear-Gaussian model p(z_t | z_{t-1}) = N(z_{t-1} + b_t, variance).

  Also exposes the closed-form posterior, lookahead, and marginal
  likelihood for this model class.
  """

  def __init__(self,
               state_size,
               num_timesteps,
               sigma_min=1e-5,
               variance=1.0,
               dtype=tf.float32,
               random_seed=None,
               trainable=True,
               init_bs_to_zero=False,
               graph_collection_name="P_VARS"):
    """Creates the per-timestep drift variables b_t.

    Args:
      state_size: Dimension of the latent state.
      num_timesteps: Number of timesteps in the chain.
      sigma_min: Unused floor kept for interface parity with Q/R.
      variance: Fixed transition variance.
      dtype: dtype of the created variables.
      random_seed: Seed for the uniform initializer.
      trainable: Whether the b_t variables are trainable.
      init_bs_to_zero: If True, all b_t start at zero.
      graph_collection_name: Graph collection the variables are added to.
    """
    self.state_size = state_size
    self.num_timesteps = num_timesteps
    self.sigma_min = sigma_min
    self.dtype = dtype
    self.variance = variance
    self.graph_collection_name = graph_collection_name
    # `range` (not `xrange`) keeps this Python 3 compatible.
    if init_bs_to_zero:
      initializers = [tf.zeros_initializer for _ in range(num_timesteps)]
    else:
      initializers = [tf.random_uniform_initializer(seed=random_seed)
                      for _ in range(num_timesteps)]
    self.bs = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="p_b_%d" % (t + 1),
            initializer=initializers[t],
            collections=[tf.GraphKeys.GLOBAL_VARIABLES, graph_collection_name],
            trainable=trainable) for t in range(num_timesteps)
    ]
    # Bs[i] = sum_{k=i+1}^{n} b_k (suffix sums of the drifts).
    self.Bs = tf.cumsum(self.bs, reverse=True, axis=0)

  def posterior(self, observation, prev_state, t):
    """Computes the true posterior p(z_t|z_{t-1}, z_n)."""
    # bs[0] is really b_1; Bs[i] is the sum from k=i+1^n of b_k.
    mu = observation - self.Bs[t]
    if t > 0:
      mu += (prev_state + self.bs[t - 1]) * float(self.num_timesteps - t)
    mu /= float(self.num_timesteps - t + 1)
    sigma = tf.ones_like(mu) * self.variance * (
        float(self.num_timesteps - t) / float(self.num_timesteps - t + 1))
    return tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma))

  def lookahead(self, state, t):
    """Computes the true lookahead distribution p(z_n|z_t)."""
    mu = state + self.Bs[t]
    sigma = tf.ones_like(state) * self.variance * float(self.num_timesteps - t)
    return tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma))

  def likelihood(self, observation):
    """Computes the marginal log-likelihood of the observation."""
    batch_size = tf.shape(observation)[0]
    mu = tf.tile(tf.reduce_sum(self.bs, axis=0)[tf.newaxis, :], [batch_size, 1])
    sigma = tf.ones_like(mu) * self.variance * (self.num_timesteps + 1)
    dist = tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma))
    # Average over the batch and take the sum over the state size.
    return tf.reduce_mean(tf.reduce_sum(dist.log_prob(observation), axis=1))

  def p_zt(self, prev_state, t):
    """Computes the model p(z_t| z_{t-1})."""
    batch_size = tf.shape(prev_state)[0]
    if t > 0:
      z_mu_p = prev_state + self.bs[t - 1]
    else:  # p(z_0) is Normal(0,1)
      z_mu_p = tf.zeros([batch_size, self.state_size], dtype=self.dtype)
    p_zt = tf.contrib.distributions.Normal(
        loc=z_mu_p, scale=tf.sqrt(tf.ones_like(z_mu_p) * self.variance))
    return p_zt

  def generative(self, unused_observation, z_nm1):
    """Computes the model's generative distribution p(z_n| z_{n-1})."""
    generative_p_mu = z_nm1 + self.bs[-1]
    return tf.contrib.distributions.Normal(
        loc=generative_p_mu,
        scale=tf.sqrt(tf.ones_like(generative_p_mu) * self.variance))
class ShortChainNonlinearP(object):
  """Nonlinear transition model p(z_t | z_{t-1}) with a configurable family.

  The transition location is either z_{t-1} (standard) or round(z_{t-1}),
  and the distribution family (e.g. Normal or Cauchy) is injectable.
  """

  def __init__(self,
               state_size,
               num_timesteps,
               sigma_min=1e-5,
               variance=1.0,
               observation_variance=1.0,
               transition_type=STANDARD_TRANSITION,
               transition_dist=tf.contrib.distributions.Normal,
               dtype=tf.float32,
               random_seed=None):
    """Stores configuration; no variables are created (p is fixed).

    Args:
      state_size: Dimension of the latent state.
      num_timesteps: Number of timesteps in the chain.
      sigma_min: Unused floor kept for interface parity.
      variance: Transition variance.
      observation_variance: Emission variance.
      transition_type: One of TRANSITION_TYPES.
      transition_dist: Distribution class used for transitions/emissions.
      dtype: dtype for created tensors.
      random_seed: Unused; kept for interface parity.
    """
    self.state_size = state_size
    self.num_timesteps = num_timesteps
    self.sigma_min = sigma_min
    self.dtype = dtype
    self.variance = variance
    self.observation_variance = observation_variance
    self.transition_type = transition_type
    self.transition_dist = transition_dist

  def p_zt(self, prev_state, t):
    """Computes the model p(z_t| z_{t-1})."""
    batch_size = tf.shape(prev_state)[0]
    if t > 0:
      if self.transition_type == ROUND_TRANSITION:
        loc = tf.round(prev_state)
        tf.logging.info("p(z_%d | z_%d) ~ N(round(z_%d), %0.1f)" %
                        (t, t - 1, t - 1, self.variance))
      elif self.transition_type == STANDARD_TRANSITION:
        loc = prev_state
        tf.logging.info("p(z_%d | z_%d) ~ N(z_%d, %0.1f)" %
                        (t, t - 1, t - 1, self.variance))
      else:
        # Previously `loc` was left unbound here, raising NameError below.
        raise ValueError("Unknown transition_type: %s" % self.transition_type)
    else:  # p(z_0) is Normal(0,1)
      loc = tf.zeros([batch_size, self.state_size], dtype=self.dtype)
      tf.logging.info("p(z_0) ~ N(0,%0.1f)" % self.variance)
    p_zt = self.transition_dist(
        loc=loc,
        scale=tf.sqrt(tf.ones_like(loc) * self.variance))
    return p_zt

  def generative(self, unused_obs, z_ni):
    """Computes the model's generative distribution p(x_i| z_{ni})."""
    if self.transition_type == ROUND_TRANSITION:
      loc = tf.round(z_ni)
    elif self.transition_type == STANDARD_TRANSITION:
      loc = z_ni
    else:
      # Previously `loc` was left unbound here, raising NameError below.
      raise ValueError("Unknown transition_type: %s" % self.transition_type)
    generative_sigma_sq = tf.ones_like(loc) * self.observation_variance
    return self.transition_dist(
        loc=loc, scale=tf.sqrt(generative_sigma_sq))
class BimodalPriorP(object):
  """Linear-Gaussian model like P, but with a two-mode Gaussian-mixture
  prior on z_0 located at +/- prior_mode_mean."""

  def __init__(self,
               state_size,
               num_timesteps,
               mixing_coeff=0.5,
               prior_mode_mean=1,
               sigma_min=1e-5,
               variance=1.0,
               dtype=tf.float32,
               random_seed=None,
               trainable=True,
               init_bs_to_zero=False,
               graph_collection_name="P_VARS"):
    """Creates the per-timestep drift variables b_t.

    Args:
      state_size: Dimension of the latent state.
      num_timesteps: Number of timesteps in the chain.
      mixing_coeff: Weight of the positive mode in the z_0 mixture.
      prior_mode_mean: Magnitude of the two mode means (+/-).
      sigma_min: Unused floor kept for interface parity.
      variance: Fixed transition variance.
      dtype: dtype of the created variables.
      random_seed: Seed for the uniform initializer.
      trainable: Whether the b_t variables are trainable.
      init_bs_to_zero: If True, all b_t start at zero.
      graph_collection_name: Graph collection the variables are added to.
    """
    self.state_size = state_size
    self.num_timesteps = num_timesteps
    self.sigma_min = sigma_min
    self.dtype = dtype
    self.variance = variance
    self.mixing_coeff = mixing_coeff
    self.prior_mode_mean = prior_mode_mean
    # `range` (not `xrange`) keeps this Python 3 compatible.
    if init_bs_to_zero:
      initializers = [tf.zeros_initializer for _ in range(num_timesteps)]
    else:
      initializers = [tf.random_uniform_initializer(seed=random_seed)
                      for _ in range(num_timesteps)]
    self.bs = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="b_%d" % (t + 1),
            initializer=initializers[t],
            collections=[tf.GraphKeys.GLOBAL_VARIABLES, graph_collection_name],
            trainable=trainable) for t in range(num_timesteps)
    ]
    # Bs[i] = sum_{k=i+1}^{n} b_k (suffix sums of the drifts).
    self.Bs = tf.cumsum(self.bs, reverse=True, axis=0)

  def posterior(self, observation, prev_state, t):
    # NOTE: This is currently wrong, but would require a refactoring of
    # summarize_q to fix as kl is not defined for a mixture.
    """Computes the true posterior p(z_t|z_{t-1}, z_n)."""
    # bs[0] is really b_1; Bs[i] is the sum from k=i+1^n of b_k.
    mu = observation - self.Bs[t]
    if t > 0:
      mu += (prev_state + self.bs[t - 1]) * float(self.num_timesteps - t)
    mu /= float(self.num_timesteps - t + 1)
    sigma = tf.ones_like(mu) * self.variance * (
        float(self.num_timesteps - t) / float(self.num_timesteps - t + 1))
    return tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma))

  def lookahead(self, state, t):
    """Computes the true lookahead distribution p(z_n|z_t)."""
    mu = state + self.Bs[t]
    sigma = tf.ones_like(state) * self.variance * float(self.num_timesteps - t)
    return tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma))

  def likelihood(self, observation):
    """Computes the marginal log-likelihood under the two-mode prior."""
    batch_size = tf.shape(observation)[0]
    sum_of_bs = tf.tile(
        tf.reduce_sum(self.bs, axis=0)[tf.newaxis, :], [batch_size, 1])
    sigma = tf.ones_like(sum_of_bs) * self.variance * (self.num_timesteps + 1)
    mu_pos = (tf.ones([batch_size, self.state_size], dtype=self.dtype) *
              self.prior_mode_mean) + sum_of_bs
    mu_neg = (tf.ones([batch_size, self.state_size], dtype=self.dtype) *
              -self.prior_mode_mean) + sum_of_bs
    zn_pos = tf.contrib.distributions.Normal(
        loc=mu_pos,
        scale=tf.sqrt(sigma))
    zn_neg = tf.contrib.distributions.Normal(
        loc=mu_neg,
        scale=tf.sqrt(sigma))
    # NOTE(review): mode_probs is float64 while the components use
    # self.dtype (default float32) — confirm the Mixture accepts this.
    mode_probs = tf.convert_to_tensor(
        [self.mixing_coeff, 1 - self.mixing_coeff], dtype=tf.float64)
    mode_probs = tf.tile(mode_probs[tf.newaxis, tf.newaxis, :],
                         [batch_size, 1, 1])
    mode_selection_dist = tf.contrib.distributions.Categorical(
        probs=mode_probs)
    zn_dist = tf.contrib.distributions.Mixture(
        cat=mode_selection_dist,
        components=[zn_pos, zn_neg],
        validate_args=True)
    # Average over the batch and take the sum over the state size.
    return tf.reduce_mean(tf.reduce_sum(zn_dist.log_prob(observation), axis=1))

  def p_zt(self, prev_state, t):
    """Computes the model p(z_t| z_{t-1})."""
    batch_size = tf.shape(prev_state)[0]
    if t > 0:
      z_mu_p = prev_state + self.bs[t - 1]
      p_zt = tf.contrib.distributions.Normal(
          loc=z_mu_p, scale=tf.sqrt(tf.ones_like(z_mu_p) * self.variance))
      return p_zt
    else:  # p(z_0) is a mixture of two Normals
      mu_pos = (tf.ones([batch_size, self.state_size], dtype=self.dtype) *
                self.prior_mode_mean)
      mu_neg = (tf.ones([batch_size, self.state_size], dtype=self.dtype) *
                -self.prior_mode_mean)
      z0_pos = tf.contrib.distributions.Normal(
          loc=mu_pos,
          scale=tf.sqrt(tf.ones_like(mu_pos) * self.variance))
      z0_neg = tf.contrib.distributions.Normal(
          loc=mu_neg,
          scale=tf.sqrt(tf.ones_like(mu_neg) * self.variance))
      mode_probs = tf.convert_to_tensor(
          [self.mixing_coeff, 1 - self.mixing_coeff], dtype=tf.float64)
      mode_probs = tf.tile(mode_probs[tf.newaxis, tf.newaxis, :],
                           [batch_size, 1, 1])
      mode_selection_dist = tf.contrib.distributions.Categorical(
          probs=mode_probs)
      z0_dist = tf.contrib.distributions.Mixture(
          cat=mode_selection_dist,
          components=[z0_pos, z0_neg],
          validate_args=False)
      return z0_dist

  def generative(self, unused_observation, z_nm1):
    """Computes the model's generative distribution p(z_n| z_{n-1})."""
    generative_p_mu = z_nm1 + self.bs[-1]
    return tf.contrib.distributions.Normal(
        loc=generative_p_mu,
        scale=tf.sqrt(tf.ones_like(generative_p_mu) * self.variance))
class Model(object):
  """Bundles p, q, and r into a single importance-weighted step function."""

  def __init__(self,
               p,
               q,
               r,
               state_size,
               num_timesteps,
               dtype=tf.float32):
    """Stores the component models and chain configuration."""
    self.p = p
    self.q = q
    self.r = r
    self.state_size = state_size
    self.num_timesteps = num_timesteps
    self.dtype = dtype

  def zero_state(self, batch_size):
    """Returns an all-zero initial state of shape [batch_size, state_size]."""
    return tf.zeros([batch_size, self.state_size], dtype=self.dtype)

  def __call__(self, prev_state, observation, t):
    """Samples z_t from q and returns the log-weight components.

    Returns:
      Tuple (zt, log_q_zt, log_p_zt, log_p_x_given_z, log_r_xn), each
      log-prob summed over the state dimension.
    """
    # Compute the q distribution over z, q(z_t|z_n, z_{t-1}).
    q_zt = self.q.q_zt(observation, prev_state, t)
    # Compute the p distribution over z, p(z_t|z_{t-1}).
    p_zt = self.p.p_zt(prev_state, t)
    # Sample from q.
    zt = q_zt.sample()
    r_xn = self.r.r_xn(zt, t)
    # Calculate the logprobs and sum over the state size.
    log_q_zt = tf.reduce_sum(q_zt.log_prob(zt), axis=1)
    log_p_zt = tf.reduce_sum(p_zt.log_prob(zt), axis=1)
    log_r_xn = tf.reduce_sum(r_xn.log_prob(observation), axis=1)
    # If we're at the last timestep, also calc the logprob of the observation.
    if t == self.num_timesteps - 1:
      generative_dist = self.p.generative(observation, zt)
      log_p_x_given_z = tf.reduce_sum(
          generative_dist.log_prob(observation), axis=1)
    else:
      log_p_x_given_z = tf.zeros_like(log_q_zt)
    return (zt, log_q_zt, log_p_zt, log_p_x_given_z, log_r_xn)

  @staticmethod
  def create(state_size,
             num_timesteps,
             sigma_min=1e-5,
             r_sigma_init=1,
             variance=1.0,
             mixing_coeff=0.5,
             prior_mode_mean=1.0,
             dtype=tf.float32,
             random_seed=None,
             train_p=True,
             p_type="unimodal",
             q_type="normal",
             observation_variance=1.0,
             transition_type=STANDARD_TRANSITION,
             use_bs=True):
    """Builds a Model from string-selected p and q families.

    Raises:
      ValueError: If p_type or q_type is not recognized. (Previously an
        unknown type left `p`/`q_class` unbound and raised NameError.)
    """
    if p_type == "unimodal":
      p = P(state_size,
            num_timesteps,
            sigma_min=sigma_min,
            variance=variance,
            dtype=dtype,
            random_seed=random_seed,
            trainable=train_p,
            init_bs_to_zero=not use_bs)
    elif p_type == "bimodal":
      p = BimodalPriorP(
          state_size,
          num_timesteps,
          mixing_coeff=mixing_coeff,
          prior_mode_mean=prior_mode_mean,
          sigma_min=sigma_min,
          variance=variance,
          dtype=dtype,
          random_seed=random_seed,
          trainable=train_p,
          init_bs_to_zero=not use_bs)
    elif "nonlinear" in p_type:
      if "cauchy" in p_type:
        trans_dist = tf.contrib.distributions.Cauchy
      else:
        trans_dist = tf.contrib.distributions.Normal
      p = ShortChainNonlinearP(
          state_size,
          num_timesteps,
          sigma_min=sigma_min,
          variance=variance,
          observation_variance=observation_variance,
          transition_type=transition_type,
          transition_dist=trans_dist,
          dtype=dtype,
          random_seed=random_seed
      )
    else:
      raise ValueError("Unknown p_type: %s" % p_type)
    if q_type == "normal":
      q_class = Q
    elif q_type == "simple_mean":
      q_class = SimpleMeanQ
    elif q_type == "prev_state":
      q_class = PreviousStateQ
    elif q_type == "observation":
      q_class = ObservationQ
    else:
      raise ValueError("Unknown q_type: %s" % q_type)
    q = q_class(state_size,
                num_timesteps,
                sigma_min=sigma_min,
                dtype=dtype,
                random_seed=random_seed,
                init_mu0_to_zero=not use_bs)
    r = R(state_size,
          num_timesteps,
          sigma_min=sigma_min,
          sigma_init=r_sigma_init,
          dtype=dtype,
          random_seed=random_seed)
    model = Model(p, q, r, state_size, num_timesteps, dtype=dtype)
    return model
class BackwardsModel(object):
  """A model whose proposal runs backwards in time, q(z_t | z_{t+1}, x).

  Indexing note: `t` counts forward; `t_backwards = num_timesteps - t - 1`
  maps it into the parameter lists.
  """

  def __init__(self,
               state_size,
               num_timesteps,
               sigma_min=1e-5,
               dtype=tf.float32):
    """Creates drift, proposal, and r-twist variables for each timestep."""
    self.state_size = state_size
    self.num_timesteps = num_timesteps
    self.sigma_min = sigma_min
    self.dtype = dtype
    # `range` (not `xrange`) keeps this Python 3 compatible.
    self.bs = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="b_%d" % (t + 1),
            initializer=tf.zeros_initializer) for t in range(num_timesteps)
    ]
    # Bs[i] = sum_{k=i+1}^{n} b_k (suffix sums of the drifts).
    self.Bs = tf.cumsum(self.bs, reverse=True, axis=0)
    self.q_mus = [
        snt.Linear(output_size=state_size) for _ in range(num_timesteps)
    ]
    self.q_sigmas = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="q_sigma_%d" % (t + 1),
            initializer=tf.zeros_initializer) for t in range(num_timesteps)
    ]
    self.r_mus = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="r_mu_%d" % (t + 1),
            initializer=tf.zeros_initializer) for t in range(num_timesteps)
    ]
    self.r_sigmas = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="r_sigma_%d" % (t + 1),
            initializer=tf.zeros_initializer) for t in range(num_timesteps)
    ]

  def zero_state(self, batch_size):
    """Returns an all-zero initial state of shape [batch_size, state_size]."""
    return tf.zeros([batch_size, self.state_size], dtype=self.dtype)

  def posterior(self, unused_observation, prev_state, unused_t):
    # TODO(dieterichl): Correct this.
    return tf.contrib.distributions.Normal(
        loc=tf.zeros_like(prev_state), scale=tf.zeros_like(prev_state))

  def lookahead(self, state, unused_t):
    # TODO(dieterichl): Correct this.
    return tf.contrib.distributions.Normal(
        loc=tf.zeros_like(state), scale=tf.zeros_like(state))

  def q_zt(self, observation, next_state, t):
    """Computes the variational posterior q(z_{t}|z_{t+1}, z_n)."""
    t_backwards = self.num_timesteps - t - 1
    batch_size = tf.shape(next_state)[0]
    q_mu = self.q_mus[t_backwards](
        tf.concat([observation, next_state], axis=1))
    q_sigma = tf.maximum(
        tf.nn.softplus(self.q_sigmas[t_backwards]), self.sigma_min)
    q_sigma = tf.tile(q_sigma[tf.newaxis, :], [batch_size, 1])
    q_zt = tf.contrib.distributions.Normal(loc=q_mu, scale=tf.sqrt(q_sigma))
    return q_zt

  def p_zt(self, prev_state, t):
    """Computes the model p(z_{t+1}| z_{t})."""
    t_backwards = self.num_timesteps - t - 1
    z_mu_p = prev_state + self.bs[t_backwards]
    p_zt = tf.contrib.distributions.Normal(
        loc=z_mu_p, scale=tf.ones_like(z_mu_p))
    return p_zt

  def generative(self, unused_observation, z_nm1):
    """Computes the model's generative distribution p(z_n| z_{n-1})."""
    generative_p_mu = z_nm1 + self.bs[-1]
    return tf.contrib.distributions.Normal(
        loc=generative_p_mu, scale=tf.ones_like(generative_p_mu))

  def r(self, z_t, t):
    """Returns the r twist distribution at forward timestep t."""
    t_backwards = self.num_timesteps - t - 1
    batch_size = tf.shape(z_t)[0]
    r_mu = tf.tile(self.r_mus[t_backwards][tf.newaxis, :], [batch_size, 1])
    r_sigma = tf.maximum(
        tf.nn.softplus(self.r_sigmas[t_backwards]), self.sigma_min)
    r_sigma = tf.tile(r_sigma[tf.newaxis, :], [batch_size, 1])
    return tf.contrib.distributions.Normal(loc=r_mu, scale=tf.sqrt(r_sigma))

  def likelihood(self, observation):
    """Computes the marginal log-likelihood of the observation."""
    batch_size = tf.shape(observation)[0]
    mu = tf.tile(tf.reduce_sum(self.bs, axis=0)[tf.newaxis, :], [batch_size, 1])
    sigma = tf.ones_like(mu) * (self.num_timesteps + 1)
    dist = tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma))
    # Average over the batch and take the sum over the state size.
    return tf.reduce_mean(tf.reduce_sum(dist.log_prob(observation), axis=1))

  def __call__(self, next_state, observation, t):
    """One backward step: sample z_t and return log-weight components."""
    # next_state = z_{t+1}.
    # Compute the q distribution over z, q(z_{t}|z_n, z_{t+1}).
    q_zt = self.q_zt(observation, next_state, t)
    # Sample from q.
    zt = q_zt.sample()
    # Compute the p distribution over z, p(z_{t+1}|z_{t}).
    p_zt = self.p_zt(zt, t)
    # Compute log p(z_{t+1} | z_t).
    if t == 0:
      log_p_zt = p_zt.log_prob(observation)
    else:
      log_p_zt = p_zt.log_prob(next_state)
    # Compute r prior over zt.
    r_zt = self.r(zt, t)
    log_r_zt = r_zt.log_prob(zt)
    # Compute proposal density at zt.
    log_q_zt = q_zt.log_prob(zt)
    # If we're at the last timestep, also calc the logprob of the observation.
    if t == self.num_timesteps - 1:
      p_z0_dist = tf.contrib.distributions.Normal(
          loc=tf.zeros_like(zt), scale=tf.ones_like(zt))
      z0_log_prob = p_z0_dist.log_prob(zt)
    else:
      z0_log_prob = tf.zeros_like(log_q_zt)
    return (zt, log_q_zt, log_p_zt, z0_log_prob, log_r_zt)
class LongChainP(object):
  """Transition/emission model for long chains with sparse observations.

  There are `steps_per_obs` latent transitions between consecutive
  observations, for `num_obs` observations total.
  """

  def __init__(self,
               state_size,
               num_obs,
               steps_per_obs,
               sigma_min=1e-5,
               variance=1.0,
               observation_variance=1.0,
               observation_type=STANDARD_OBSERVATION,
               transition_type=STANDARD_TRANSITION,
               dtype=tf.float32,
               random_seed=None):
    """Stores configuration; no variables are created (p is fixed)."""
    self.state_size = state_size
    self.steps_per_obs = steps_per_obs
    self.num_obs = num_obs
    self.num_timesteps = steps_per_obs * num_obs + 1
    self.sigma_min = sigma_min
    self.dtype = dtype
    self.variance = variance
    self.observation_variance = observation_variance
    self.observation_type = observation_type
    self.transition_type = transition_type

  def likelihood(self, observations):
    """Computes the model's true likelihood of the observations.

    Args:
      observations: A [batch_size, m, state_size] Tensor representing each of
        the m observations.

    Raises:
      ValueError: Always; the closed-form likelihood is intentionally not
        implemented for long-chain models.
    """
    raise ValueError("Likelihood is not defined for long-chain models")

  def p_zt(self, prev_state, t):
    """Computes the model p(z_t| z_{t-1})."""
    batch_size = tf.shape(prev_state)[0]
    if t > 0:
      if self.transition_type == ROUND_TRANSITION:
        loc = tf.round(prev_state)
        tf.logging.info("p(z_%d | z_%d) ~ N(round(z_%d), %0.1f)" %
                        (t, t - 1, t - 1, self.variance))
      elif self.transition_type == STANDARD_TRANSITION:
        loc = prev_state
        tf.logging.info("p(z_%d | z_%d) ~ N(z_%d, %0.1f)" %
                        (t, t - 1, t - 1, self.variance))
      else:
        # Previously `loc` was left unbound here, raising NameError below.
        raise ValueError("Unknown transition_type: %s" % self.transition_type)
    else:  # p(z_0) is Normal(0,1)
      loc = tf.zeros([batch_size, self.state_size], dtype=self.dtype)
      tf.logging.info("p(z_0) ~ N(0,%0.1f)" % self.variance)
    p_zt = tf.contrib.distributions.Normal(
        loc=loc,
        scale=tf.sqrt(tf.ones_like(loc) * self.variance))
    return p_zt

  def generative(self, z_ni, t):
    """Computes the model's generative distribution p(x_i| z_{ni})."""
    if self.observation_type == SQUARED_OBSERVATION:
      generative_mu = tf.square(z_ni)
      tf.logging.info("p(x_%d | z_%d) ~ N(z_%d^2, %0.1f)" %
                      (t, t, t, self.variance))
    elif self.observation_type == ABS_OBSERVATION:
      generative_mu = tf.abs(z_ni)
      tf.logging.info("p(x_%d | z_%d) ~ N(|z_%d|, %0.1f)" %
                      (t, t, t, self.variance))
    elif self.observation_type == STANDARD_OBSERVATION:
      generative_mu = z_ni
      tf.logging.info("p(x_%d | z_%d) ~ N(z_%d, %0.1f)" %
                      (t, t, t, self.variance))
    else:
      # Previously `generative_mu` was left unbound, raising NameError below.
      raise ValueError("Unknown observation_type: %s" % self.observation_type)
    generative_sigma_sq = tf.ones_like(generative_mu) * self.observation_variance
    return tf.contrib.distributions.Normal(
        loc=generative_mu, scale=tf.sqrt(generative_sigma_sq))
class LongChainQ(object):
  """Proposal q(z_t | z_{t-1}, relevant future observations) for long chains."""

  def __init__(self,
               state_size,
               num_obs,
               steps_per_obs,
               sigma_min=1e-5,
               dtype=tf.float32,
               random_seed=None):
    """Creates per-timestep linear mean layers and variance variables."""
    self.state_size = state_size
    self.sigma_min = sigma_min
    self.dtype = dtype
    self.steps_per_obs = steps_per_obs
    self.num_obs = num_obs
    self.num_timesteps = num_obs * steps_per_obs + 1
    initializers = {
        "w": tf.random_uniform_initializer(seed=random_seed),
        "b": tf.zeros_initializer
    }
    # `range` (not `xrange`) keeps this Python 3 compatible.
    self.mus = [
        snt.Linear(output_size=state_size, initializers=initializers)
        for t in range(self.num_timesteps)
    ]
    self.sigmas = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="q_sigma_%d" % (t + 1),
            initializer=tf.random_uniform_initializer(seed=random_seed))
        for t in range(self.num_timesteps)
    ]

  def first_relevant_obs_index(self, t):
    """Index of the first observation still relevant at timestep t.

    NOTE(review): q_zt below computes the same quantity inline with a
    slightly different (but equivalent for t >= 0) formula; consider
    unifying.
    """
    return int(max((t - 1) / self.steps_per_obs, 0))

  def q_zt(self, observations, prev_state, t):
    """Computes a distribution over z_t.

    Args:
      observations: a [batch_size, num_observations, state_size] Tensor.
      prev_state: a [batch_size, state_size] Tensor.
      t: The current timestep, a Python int (used for list indexing and
        slicing, so it cannot be a Tensor).
    """
    # Filter out unneeded past observations.
    first_relevant_obs_index = int(
        math.floor(max(t - 1, 0) / self.steps_per_obs))
    num_relevant_observations = self.num_obs - first_relevant_obs_index
    observations = observations[:, first_relevant_obs_index:, :]
    batch_size = tf.shape(prev_state)[0]
    # Concatenate the prev state and observations along the second axis
    # (that is not the batch or state size axis), then flatten it to
    # [batch_size, (num_relevant_observations + 1) * state_size] to feed it
    # into the linear layer.
    q_input = tf.concat([observations, prev_state[:, tf.newaxis, :]], axis=1)
    q_input = tf.reshape(
        q_input,
        [batch_size, (num_relevant_observations + 1) * self.state_size])
    q_mu = self.mus[t](q_input)
    q_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min)
    q_sigma = tf.tile(q_sigma[tf.newaxis, :], [batch_size, 1])
    q_zt = tf.contrib.distributions.Normal(loc=q_mu, scale=tf.sqrt(q_sigma))
    tf.logging.info(
        "q(z_{t} | z_{tm1}, x_{obsf}:{obst}) ~ N(Linear([z_{tm1},x_{obsf}:{obst}]), sigma_{t})".format(
            **{"t": t,
               "tm1": t - 1,
               "obsf": (first_relevant_obs_index + 1) * self.steps_per_obs,
               "obst": self.steps_per_obs * self.num_obs}))
    return q_zt

  def summarize_weights(self):
    """No-op; kept for interface parity with the other Q classes."""
    pass
class LongChainR(object):
  """Lookahead distribution r(x_future | z_t) for long-chain models.

  Each timestep owns a learned variance vector with one entry per
  remaining future observation; the mean for every future observation
  is the current latent state.
  """

  def __init__(self,
               state_size,
               num_obs,
               steps_per_obs,
               sigma_min=1e-5,
               dtype=tf.float32,
               random_seed=None):
    """Creates one variance variable per timestep."""
    self.state_size = state_size
    self.dtype = dtype
    self.sigma_min = sigma_min
    self.steps_per_obs = steps_per_obs
    self.num_obs = num_obs
    self.num_timesteps = num_obs * steps_per_obs + 1
    self.sigmas = []
    for step in range(self.num_timesteps):
      self.sigmas.append(
          tf.get_variable(
              shape=[self.num_future_obs(step)],
              dtype=self.dtype,
              name="r_sigma_%d" % (step + 1),
              initializer=tf.constant_initializer(1.0)))

  def first_future_obs_index(self, t):
    """Index of the first observation occurring strictly after timestep t."""
    # Floor division of non-negative ints matches math.floor(t / steps).
    return t // self.steps_per_obs

  def num_future_obs(self, t):
    """Number of observations that remain after timestep t."""
    return self.num_obs - self.first_future_obs_index(t)

  def r_xn(self, z_t, t):
    """Computes a distribution over the future observations given current latent
    state.

    The indexing in these messages is 1 indexed and inclusive. This is
    consistent with the latex documents.

    Args:
      z_t: [batch_size, state_size] Tensor
      t: Current timestep
    """
    tf.logging.info(
        "r(x_{start}:{end} | z_{t}) ~ N(z_{t}, sigma_{t})".format(
            t=t,
            start=(self.first_future_obs_index(t) + 1) * self.steps_per_obs,
            end=self.num_timesteps - 1))
    batch_size = tf.shape(z_t)[0]
    remaining = self.num_future_obs(t)
    # Every future observation shares the same mean: the current state.
    # Tiling yields a [batch_size, num_future_obs, state_size] Tensor.
    loc = tf.tile(z_t[:, tf.newaxis, :], [1, remaining, 1])
    # One positive variance per future observation, shared across the
    # state dimensions, broadcast to [batch_size, remaining, state_size].
    var = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min)
    var = tf.tile(var[tf.newaxis, :, tf.newaxis],
                  [batch_size, 1, self.state_size])
    return tf.contrib.distributions.Normal(loc=loc, scale=tf.sqrt(var))

  def summarize_weights(self):
    """No-op; kept for interface parity with the other R classes."""
    pass
class LongChainModel(object):
  """Bundles LongChainP/Q/R into one importance-weighted step function."""

  def __init__(self,
               p,
               q,
               r,
               state_size,
               num_obs,
               steps_per_obs,
               dtype=tf.float32,
               disable_r=False):
    """Stores the component models and chain configuration.

    Args:
      p: Transition/emission model (e.g. LongChainP).
      q: Proposal model (e.g. LongChainQ).
      r: Lookahead model (e.g. LongChainR).
      state_size: Dimension of the latent state.
      num_obs: Number of observations in the chain.
      steps_per_obs: Latent transitions between consecutive observations.
      dtype: dtype for created tensors.
      disable_r: If True, the r lookahead term is zeroed out.
    """
    self.p = p
    self.q = q
    self.r = r
    self.disable_r = disable_r
    self.state_size = state_size
    self.num_obs = num_obs
    self.steps_per_obs = steps_per_obs
    # Total chain length: one initial state plus steps_per_obs per obs.
    self.num_timesteps = steps_per_obs * num_obs + 1
    self.dtype = dtype

  def zero_state(self, batch_size):
    """Returns an all-zero initial state of shape [batch_size, state_size]."""
    return tf.zeros([batch_size, self.state_size], dtype=self.dtype)

  def next_obs_ind(self, t):
    """Index of the observation associated with timestep t (clamped at 0)."""
    return int(math.floor(max(t - 1, 0) / self.steps_per_obs))

  def __call__(self, prev_state, observations, t):
    """Computes the importance weight for the model system.

    Args:
      prev_state: [batch_size, state_size] Tensor
      observations: [batch_size, num_observations, state_size] Tensor
    """
    # Compute the q distribution over z, q(z_t|z_n, z_{t-1}).
    q_zt = self.q.q_zt(observations, prev_state, t)
    # Compute the p distribution over z, p(z_t|z_{t-1}).
    p_zt = self.p.p_zt(prev_state, t)
    # sample from q and evaluate the logprobs, summing over the state size
    zt = q_zt.sample()
    log_q_zt = tf.reduce_sum(q_zt.log_prob(zt), axis=1)
    log_p_zt = tf.reduce_sum(p_zt.log_prob(zt), axis=1)
    if not self.disable_r and t < self.num_timesteps-1:
      # score the remaining observations using r
      r_xn = self.r.r_xn(zt, t)
      log_r_xn = r_xn.log_prob(observations[:, self.next_obs_ind(t+1):, :])
      # sum over state size and observation, leaving the batch index
      log_r_xn = tf.reduce_sum(log_r_xn, axis=[1,2])
    else:
      # No lookahead at the final step or when r is disabled.
      log_r_xn = tf.zeros_like(log_p_zt)
    # Emission log-prob only at observation timesteps (multiples of
    # steps_per_obs, excluding t=0).
    if t != 0 and t % self.steps_per_obs == 0:
      generative_dist = self.p.generative(zt, t)
      log_p_x_given_z = generative_dist.log_prob(observations[:,self.next_obs_ind(t),:])
      log_p_x_given_z = tf.reduce_sum(log_p_x_given_z, axis=1)
    else:
      log_p_x_given_z = tf.zeros_like(log_q_zt)
    return (zt, log_q_zt, log_p_zt, log_p_x_given_z, log_r_xn)

  @staticmethod
  def create(state_size,
             num_obs,
             steps_per_obs,
             sigma_min=1e-5,
             variance=1.0,
             observation_variance=1.0,
             observation_type=STANDARD_OBSERVATION,
             transition_type=STANDARD_TRANSITION,
             dtype=tf.float32,
             random_seed=None,
             disable_r=False):
    """Builds a LongChainModel with matching LongChainP/Q/R components."""
    p = LongChainP(
        state_size,
        num_obs,
        steps_per_obs,
        sigma_min=sigma_min,
        variance=variance,
        observation_variance=observation_variance,
        observation_type=observation_type,
        transition_type=transition_type,
        dtype=dtype,
        random_seed=random_seed)
    q = LongChainQ(
        state_size,
        num_obs,
        steps_per_obs,
        sigma_min=sigma_min,
        dtype=dtype,
        random_seed=random_seed)
    r = LongChainR(
        state_size,
        num_obs,
        steps_per_obs,
        sigma_min=sigma_min,
        dtype=dtype,
        random_seed=random_seed)
    model = LongChainModel(
        p, q, r, state_size, num_obs, steps_per_obs,
        dtype=dtype,
        disable_r=disable_r)
    return model
class RTilde(object):
  """Per-timestep linear network producing a Gaussian's mean and variance.

  For each timestep t a separate snt.Linear maps concat(z_t, observation) to
  2*state_size outputs, split into a mean and a softplus-transformed variance.
  """

  def __init__(self,
               state_size,
               num_timesteps,
               sigma_min=1e-5,
               dtype=tf.float32,
               random_seed=None,
               graph_collection_name="R_TILDE_VARS"):
    """Creates one linear module per timestep.

    Args:
      state_size: Dimensionality of the latent state.
      num_timesteps: Number of timesteps (one linear module each).
      sigma_min: Lower clamp applied to the output variance.
      dtype: TensorFlow dtype (stored; modules infer dtype from inputs).
      random_seed: Seed for the weight initializer.
      graph_collection_name: Graph collection that tracks created variables.
    """
    self.dtype = dtype
    self.sigma_min = sigma_min
    self.graph_collection_name = graph_collection_name
    initializers = {"w": tf.truncated_normal_initializer(seed=random_seed),
                    "b": tf.zeros_initializer}

    def custom_getter(getter, *args, **kwargs):
      # Record every created variable in the named collection exactly once.
      variable = getter(*args, **kwargs)
      collection = tf.get_collection_ref(self.graph_collection_name)
      if variable not in collection:
        collection.append(variable)
      return variable

    self.fns = []
    for t in xrange(num_timesteps):
      self.fns.append(
          snt.Linear(output_size=2 * state_size,
                     initializers=initializers,
                     name="r_tilde_%d" % t,
                     custom_getter=custom_getter))

  def r_zt(self, z_t, observation, t):
    """Returns (mu, sigma_sq) of r-tilde for timestep t.

    Args:
      z_t: [batch_size, state_size] Tensor, current latent state.
      observation: [batch_size, state_size] Tensor, the observation.
      t: Python int indexing the per-timestep linear module.

    Returns:
      mu: [batch_size, state_size] mean Tensor.
      sigma_sq: [batch_size, state_size] variance Tensor, >= sigma_min.
    """
    inputs = tf.concat([z_t, observation], axis=1)
    outputs = self.fns[t](inputs)
    mu, raw_sigma_sq = tf.split(outputs, 2, axis=1)
    # Softplus keeps the variance positive; clamp below at sigma_min.
    sigma_sq = tf.maximum(tf.nn.softplus(raw_sigma_sq), self.sigma_min)
    return mu, sigma_sq
class TDModel(object):
  """Model combining p, q, and a learned r-tilde approximation.

  At each step, samples z_t from q, scores it under p, and (except at the
  final step) evaluates the r-tilde network that approximates the value of
  the remaining observation.
  """

  def __init__(self,
               p,
               q,
               r_tilde,
               state_size,
               num_timesteps,
               dtype=tf.float32,
               disable_r=False):
    """Stores the component models.

    Args:
      p: Prior/generative model providing p_zt(prev_state, t) and
        generative(observation, zt).
      q: Proposal model providing q_zt(observation, prev_state, t).
      r_tilde: RTilde instance providing r_zt(z_t, observation, t).
      state_size: Dimensionality of the latent state.
      num_timesteps: Number of timesteps in the chain.
      dtype: TensorFlow dtype of created Tensors.
      disable_r: If True, r-tilde is never evaluated and its outputs are
        returned as None.
    """
    self.p = p
    self.q = q
    self.r_tilde = r_tilde
    self.disable_r = disable_r
    self.state_size = state_size
    self.num_timesteps = num_timesteps
    self.dtype = dtype

  def zero_state(self, batch_size):
    """Returns an all-zeros initial state, shape [batch_size, state_size]."""
    return tf.zeros([batch_size, self.state_size], dtype=self.dtype)

  def __call__(self, prev_state, observation, t):
    """Computes the importance weight components for one timestep.

    Args:
      prev_state: [batch_size, state_size] Tensor.
      observation: [batch_size, state_size] Tensor; the single observation,
        scored only at the final timestep.
      t: Python int, the current timestep.

    Returns:
      Tuple (zt, log_q_zt, log_p_zt, log_p_x_given_z, r_tilde_mu,
      r_tilde_sigma_sq, p_ztplus1). The r-tilde outputs are None when
      disable_r is set or at the last timestep; p_ztplus1 is None at the
      last timestep.
    """
    # Proposal q(z_t | observation, z_{t-1}) and prior p(z_t | z_{t-1}).
    q_zt = self.q.q_zt(observation, prev_state, t)
    p_zt = self.p.p_zt(prev_state, t)
    # Sample z_t from the proposal.
    zt = q_zt.sample()
    # If it isn't the last timestep, compute the distribution over the next z.
    if t < self.num_timesteps - 1:
      p_ztplus1 = self.p.p_zt(zt, t + 1)
    else:
      p_ztplus1 = None
    # Log-probs are summed over the state dimension, leaving the batch dim.
    log_q_zt = tf.reduce_sum(q_zt.log_prob(zt), axis=1)
    log_p_zt = tf.reduce_sum(p_zt.log_prob(zt), axis=1)
    if not self.disable_r and t < self.num_timesteps - 1:
      # Approximate the value of the remaining observation using r-tilde.
      r_tilde_mu, r_tilde_sigma_sq = self.r_tilde.r_zt(zt, observation, t + 1)
    else:
      r_tilde_mu = None
      r_tilde_sigma_sq = None
    if t == self.num_timesteps - 1:
      # Only the final step scores the observation under the generative model.
      generative_dist = self.p.generative(observation, zt)
      log_p_x_given_z = tf.reduce_sum(generative_dist.log_prob(observation),
                                      axis=1)
    else:
      log_p_x_given_z = tf.zeros_like(log_q_zt)
    return (zt, log_q_zt, log_p_zt, log_p_x_given_z,
            r_tilde_mu, r_tilde_sigma_sq, p_ztplus1)

  @staticmethod
  def create(state_size,
             num_timesteps,
             sigma_min=1e-5,
             variance=1.0,
             dtype=tf.float32,
             random_seed=None,
             train_p=True,
             p_type="unimodal",
             q_type="normal",
             mixing_coeff=0.5,
             prior_mode_mean=1.0,
             observation_variance=1.0,
             transition_type=STANDARD_TRANSITION,
             use_bs=True):
    """Builds a TDModel with freshly constructed p, q, and r-tilde.

    Raises:
      ValueError: If p_type or q_type is not a recognized option.
    """
    if p_type == "unimodal":
      p = P(state_size,
            num_timesteps,
            sigma_min=sigma_min,
            variance=variance,
            dtype=dtype,
            random_seed=random_seed,
            trainable=train_p,
            init_bs_to_zero=not use_bs)
    elif p_type == "bimodal":
      p = BimodalPriorP(
          state_size,
          num_timesteps,
          mixing_coeff=mixing_coeff,
          prior_mode_mean=prior_mode_mean,
          sigma_min=sigma_min,
          variance=variance,
          dtype=dtype,
          random_seed=random_seed,
          trainable=train_p,
          init_bs_to_zero=not use_bs)
    elif "nonlinear" in p_type:
      # "cauchy" in the type name selects heavy-tailed transition noise.
      if "cauchy" in p_type:
        trans_dist = tf.contrib.distributions.Cauchy
      else:
        trans_dist = tf.contrib.distributions.Normal
      p = ShortChainNonlinearP(
          state_size,
          num_timesteps,
          sigma_min=sigma_min,
          variance=variance,
          observation_variance=observation_variance,
          transition_type=transition_type,
          transition_dist=trans_dist,
          dtype=dtype,
          random_seed=random_seed
      )
    else:
      # Previously an unknown p_type fell through and left p unbound,
      # producing a confusing NameError far from the bad argument.
      raise ValueError("Unknown p_type: %s" % p_type)
    if q_type == "normal":
      q_class = Q
    elif q_type == "simple_mean":
      q_class = SimpleMeanQ
    elif q_type == "prev_state":
      q_class = PreviousStateQ
    elif q_type == "observation":
      q_class = ObservationQ
    else:
      raise ValueError("Unknown q_type: %s" % q_type)
    q = q_class(state_size,
                num_timesteps,
                sigma_min=sigma_min,
                dtype=dtype,
                random_seed=random_seed,
                init_mu0_to_zero=not use_bs)
    r_tilde = RTilde(
        state_size,
        num_timesteps,
        sigma_min=sigma_min,
        dtype=dtype,
        random_seed=random_seed)
    model = TDModel(p, q, r_tilde, state_size, num_timesteps, dtype=dtype)
    return model