"""Model.""" |
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import math

from six.moves import xrange  # pylint: disable=redefined-builtin

import sonnet as snt

import tensorflow as tf
|
|
|
SQUARED_OBSERVATION = "squared" |
|
ABS_OBSERVATION = "abs" |
|
STANDARD_OBSERVATION = "standard" |
|
OBSERVATION_TYPES = [SQUARED_OBSERVATION, ABS_OBSERVATION, STANDARD_OBSERVATION] |
|
|
|
ROUND_TRANSITION = "round" |
|
STANDARD_TRANSITION = "standard" |
|
TRANSITION_TYPES = [ROUND_TRANSITION, STANDARD_TRANSITION] |
|
|
|
|
|
class Q(object):

  """Variational posterior q(z_t | z_{t-1}, x) with a Linear mean per step."""
|
|
|
def __init__(self, |
|
state_size, |
|
num_timesteps, |
|
sigma_min=1e-5, |
|
dtype=tf.float32, |
|
random_seed=None, |
|
init_mu0_to_zero=False, |
|
graph_collection_name="Q_VARS"): |
|
self.sigma_min = sigma_min |
|
self.dtype = dtype |
|
self.graph_collection_name = graph_collection_name |
|
initializers = [] |
|
for t in xrange(num_timesteps): |
|
if t == 0 and init_mu0_to_zero: |
|
initializers.append( |
|
{"w": tf.zeros_initializer, "b": tf.zeros_initializer}) |
|
else: |
|
initializers.append( |
|
{"w": tf.random_uniform_initializer(seed=random_seed), |
|
"b": tf.zeros_initializer}) |
|
|
|
def custom_getter(getter, *args, **kwargs): |
|
out = getter(*args, **kwargs) |
|
ref = tf.get_collection_ref(self.graph_collection_name) |
|
if out not in ref: |
|
ref.append(out) |
|
return out |
|
|
|
self.mus = [ |
|
snt.Linear(output_size=state_size, |
|
initializers=initializers[t], |
|
name="q_mu_%d" % t, |
|
custom_getter=custom_getter |
|
) |
|
for t in xrange(num_timesteps) |
|
] |
|
self.sigmas = [ |
|
tf.get_variable( |
|
shape=[state_size], |
|
dtype=self.dtype, |
|
name="q_sigma_%d" % (t + 1), |
|
collections=[tf.GraphKeys.GLOBAL_VARIABLES, graph_collection_name], |
|
initializer=tf.random_uniform_initializer(seed=random_seed)) |
|
for t in xrange(num_timesteps) |
|
] |
|
|
|
  def q_zt(self, observation, prev_state, t):

    """Computes the variational posterior q(z_t | z_{t-1}, observation)."""
|
batch_size = tf.shape(prev_state)[0] |
|
q_mu = self.mus[t](tf.concat([observation, prev_state], axis=1)) |
|
q_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min) |
|
q_sigma = tf.tile(q_sigma[tf.newaxis, :], [batch_size, 1]) |
|
q_zt = tf.contrib.distributions.Normal(loc=q_mu, scale=tf.sqrt(q_sigma)) |
|
return q_zt |
|
|
|
def summarize_weights(self): |
|
for t, sigma in enumerate(self.sigmas): |
|
tf.summary.scalar("q_sigma/%d" % t, sigma[0]) |
|
for t, f in enumerate(self.mus): |
|
tf.summary.scalar("q_mu/b_%d" % t, f.b[0]) |
|
tf.summary.scalar("q_mu/w_obs_%d" % t, f.w[0,0]) |
|
if t != 0: |
|
tf.summary.scalar("q_mu/w_prev_state_%d" % t, f.w[1,0]) |
|
|
|
|
|
class PreviousStateQ(Q):

  """Q whose mean is a function of the previous state only."""
|
|
|
def q_zt(self, unused_observation, prev_state, t): |
|
batch_size = tf.shape(prev_state)[0] |
|
q_mu = self.mus[t](prev_state) |
|
q_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min) |
|
q_sigma = tf.tile(q_sigma[tf.newaxis, :], [batch_size, 1]) |
|
q_zt = tf.contrib.distributions.Normal(loc=q_mu, scale=tf.sqrt(q_sigma)) |
|
return q_zt |
|
|
|
def summarize_weights(self): |
|
for t, sigma in enumerate(self.sigmas): |
|
tf.summary.scalar("q_sigma/%d" % t, sigma[0]) |
|
for t, f in enumerate(self.mus): |
|
tf.summary.scalar("q_mu/b_%d" % t, f.b[0]) |
|
tf.summary.scalar("q_mu/w_prev_state_%d" % t, f.w[0,0]) |
|
|
|
|
|
class ObservationQ(Q):

  """Q whose mean is a function of the observation only."""
|
|
|
def q_zt(self, observation, prev_state, t): |
|
batch_size = tf.shape(prev_state)[0] |
|
q_mu = self.mus[t](observation) |
|
q_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min) |
|
q_sigma = tf.tile(q_sigma[tf.newaxis, :], [batch_size, 1]) |
|
q_zt = tf.contrib.distributions.Normal(loc=q_mu, scale=tf.sqrt(q_sigma)) |
|
return q_zt |
|
|
|
def summarize_weights(self): |
|
for t, sigma in enumerate(self.sigmas): |
|
tf.summary.scalar("q_sigma/%d" % t, sigma[0]) |
|
for t, f in enumerate(self.mus): |
|
tf.summary.scalar("q_mu/b_%d" % t, f.b[0]) |
|
tf.summary.scalar("q_mu/w_obs_%d" % t, f.w[0,0]) |
|
|
|
|
|
class SimpleMeanQ(object):

  """Q with a learned per-step mean that ignores its inputs."""
|
|
|
def __init__(self, |
|
state_size, |
|
num_timesteps, |
|
sigma_min=1e-5, |
|
dtype=tf.float32, |
|
random_seed=None, |
|
init_mu0_to_zero=False, |
|
graph_collection_name="Q_VARS"): |
|
self.sigma_min = sigma_min |
|
self.dtype = dtype |
|
self.graph_collection_name = graph_collection_name |
|
initializers = [] |
|
for t in xrange(num_timesteps): |
|
if t == 0 and init_mu0_to_zero: |
|
        initializers.append(tf.zeros_initializer())
|
else: |
|
initializers.append(tf.random_uniform_initializer(seed=random_seed)) |
|
|
|
self.mus = [ |
|
tf.get_variable( |
|
shape=[state_size], |
|
dtype=self.dtype, |
|
name="q_mu_%d" % (t + 1), |
|
collections=[tf.GraphKeys.GLOBAL_VARIABLES, graph_collection_name], |
|
initializer=initializers[t]) |
|
for t in xrange(num_timesteps) |
|
] |
|
self.sigmas = [ |
|
tf.get_variable( |
|
shape=[state_size], |
|
dtype=self.dtype, |
|
name="q_sigma_%d" % (t + 1), |
|
collections=[tf.GraphKeys.GLOBAL_VARIABLES, graph_collection_name], |
|
initializer=tf.random_uniform_initializer(seed=random_seed)) |
|
for t in xrange(num_timesteps) |
|
] |
|
|
|
def q_zt(self, unused_observation, prev_state, t): |
|
batch_size = tf.shape(prev_state)[0] |
|
q_mu = tf.tile(self.mus[t][tf.newaxis, :], [batch_size, 1]) |
|
q_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min) |
|
q_sigma = tf.tile(q_sigma[tf.newaxis, :], [batch_size, 1]) |
|
q_zt = tf.contrib.distributions.Normal(loc=q_mu, scale=tf.sqrt(q_sigma)) |
|
return q_zt |
|
|
|
def summarize_weights(self): |
|
for t, sigma in enumerate(self.sigmas): |
|
tf.summary.scalar("q_sigma/%d" % t, sigma[0]) |
|
for t, f in enumerate(self.mus): |
|
tf.summary.scalar("q_mu/%d" % t, f[0]) |
|
|
|
|
|
class R(object):

  """Lookahead distribution r(x | z_t) with a Linear mean per step."""
|
|
|
def __init__(self, |
|
state_size, |
|
num_timesteps, |
|
sigma_min=1e-5, |
|
dtype=tf.float32, |
|
sigma_init=1., |
|
random_seed=None, |
|
graph_collection_name="R_VARS"): |
|
self.dtype = dtype |
|
self.sigma_min = sigma_min |
|
initializers = {"w": tf.truncated_normal_initializer(seed=random_seed), |
|
"b": tf.zeros_initializer} |
|
    self.graph_collection_name = graph_collection_name
|
|
|
def custom_getter(getter, *args, **kwargs): |
|
out = getter(*args, **kwargs) |
|
ref = tf.get_collection_ref(self.graph_collection_name) |
|
if out not in ref: |
|
ref.append(out) |
|
return out |
|
|
|
    self.mus = [
|
snt.Linear(output_size=state_size, |
|
initializers=initializers, |
|
name="r_mu_%d" % t, |
|
custom_getter=custom_getter) |
|
for t in xrange(num_timesteps) |
|
] |
|
|
|
self.sigmas = [ |
|
tf.get_variable( |
|
shape=[state_size], |
|
dtype=self.dtype, |
|
name="r_sigma_%d" % (t + 1), |
|
collections=[tf.GraphKeys.GLOBAL_VARIABLES, graph_collection_name], |
|
|
|
initializer=tf.constant_initializer(sigma_init)) |
|
for t in xrange(num_timesteps) |
|
] |
|
|
|
  def r_xn(self, z_t, t):

    """Computes the lookahead distribution r(x | z_t) at timestep t."""
|
batch_size = tf.shape(z_t)[0] |
|
r_mu = self.mus[t](z_t) |
|
r_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min) |
|
r_sigma = tf.tile(r_sigma[tf.newaxis, :], [batch_size, 1]) |
|
return tf.contrib.distributions.Normal( |
|
loc=r_mu, scale=tf.sqrt(r_sigma)) |
|
|
|
def summarize_weights(self): |
|
for t in range(len(self.mus) - 1): |
|
tf.summary.scalar("r_mu/%d" % t, self.mus[t][0]) |
|
tf.summary.scalar("r_sigma/%d" % t, self.sigmas[t][0]) |
|
|
|
|
|
class P(object):

  """Gaussian chain prior p(z_t | z_{t-1}) = N(z_{t-1} + b_{t-1}, variance)."""
|
|
|
def __init__(self, |
|
state_size, |
|
num_timesteps, |
|
sigma_min=1e-5, |
|
variance=1.0, |
|
dtype=tf.float32, |
|
random_seed=None, |
|
trainable=True, |
|
init_bs_to_zero=False, |
|
graph_collection_name="P_VARS"): |
|
self.state_size = state_size |
|
self.num_timesteps = num_timesteps |
|
self.sigma_min = sigma_min |
|
self.dtype = dtype |
|
self.variance = variance |
|
self.graph_collection_name = graph_collection_name |
|
    if init_bs_to_zero:

      initializers = [tf.zeros_initializer() for _ in xrange(num_timesteps)]

    else:

      initializers = [tf.random_uniform_initializer(seed=random_seed)

                      for _ in xrange(num_timesteps)]
|
|
|
self.bs = [ |
|
tf.get_variable( |
|
shape=[state_size], |
|
dtype=self.dtype, |
|
name="p_b_%d" % (t + 1), |
|
initializer=initializers[t], |
|
collections=[tf.GraphKeys.GLOBAL_VARIABLES, graph_collection_name], |
|
trainable=trainable) for t in xrange(num_timesteps) |
|
] |
|
self.Bs = tf.cumsum(self.bs, reverse=True, axis=0) |
|
|
|
def posterior(self, observation, prev_state, t): |
|
"""Computes the true posterior p(z_t|z_{t-1}, z_n).""" |
|
|
|
|
|
mu = observation - self.Bs[t] |
|
if t > 0: |
|
mu += (prev_state + self.bs[t - 1]) * float(self.num_timesteps - t) |
|
mu /= float(self.num_timesteps - t + 1) |
|
sigma = tf.ones_like(mu) * self.variance * ( |
|
float(self.num_timesteps - t) / float(self.num_timesteps - t + 1)) |
|
return tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma)) |
|
|
|
def lookahead(self, state, t): |
|
"""Computes the true lookahead distribution p(z_n|z_t).""" |
|
mu = state + self.Bs[t] |
|
sigma = tf.ones_like(state) * self.variance * float(self.num_timesteps - t) |
|
return tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma)) |
|
|
|
def likelihood(self, observation): |
|
batch_size = tf.shape(observation)[0] |
|
mu = tf.tile(tf.reduce_sum(self.bs, axis=0)[tf.newaxis, :], [batch_size, 1]) |
|
sigma = tf.ones_like(mu) * self.variance * (self.num_timesteps + 1) |
|
dist = tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma)) |
|
|
|
return tf.reduce_mean(tf.reduce_sum(dist.log_prob(observation), axis=1)) |
|
|
|
def p_zt(self, prev_state, t): |
|
"""Computes the model p(z_t| z_{t-1}).""" |
|
batch_size = tf.shape(prev_state)[0] |
|
if t > 0: |
|
z_mu_p = prev_state + self.bs[t - 1] |
|
else: |
|
z_mu_p = tf.zeros([batch_size, self.state_size], dtype=self.dtype) |
|
p_zt = tf.contrib.distributions.Normal( |
|
loc=z_mu_p, scale=tf.sqrt(tf.ones_like(z_mu_p) * self.variance)) |
|
return p_zt |
|
|
|
def generative(self, unused_observation, z_nm1): |
|
"""Computes the model's generative distribution p(z_n| z_{n-1}).""" |
|
generative_p_mu = z_nm1 + self.bs[-1] |
|
return tf.contrib.distributions.Normal( |
|
loc=generative_p_mu, scale=tf.sqrt(tf.ones_like(generative_p_mu) * self.variance)) |
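
# Illustrative sketch (not part of the original module): the closed-form

# distributions above fit together by Gaussian conjugacy; `prev_state`,

# `state`, and `observation` stand in for [batch_size, state_size] Tensors.

#

#   p = P(state_size=1, num_timesteps=3)

#   prior = p.p_zt(prev_state, t=1)   # N(z_0 + b_0, variance)

#   look = p.lookahead(state, t=1)    # N(z_1 + B_1, 2 * variance)

#   post = p.posterior(observation, prev_state, t=1)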
|
|
|
|
|
class ShortChainNonlinearP(object):

  """Chain prior whose transition mean is optionally rounded (nonlinear)."""
|
|
|
def __init__(self, |
|
state_size, |
|
num_timesteps, |
|
sigma_min=1e-5, |
|
variance=1.0, |
|
observation_variance=1.0, |
|
transition_type=STANDARD_TRANSITION, |
|
transition_dist=tf.contrib.distributions.Normal, |
|
dtype=tf.float32, |
|
random_seed=None): |
|
self.state_size = state_size |
|
self.num_timesteps = num_timesteps |
|
self.sigma_min = sigma_min |
|
self.dtype = dtype |
|
self.variance = variance |
|
self.observation_variance = observation_variance |
|
self.transition_type = transition_type |
|
self.transition_dist = transition_dist |
|
|
|
def p_zt(self, prev_state, t): |
|
"""Computes the model p(z_t| z_{t-1}).""" |
|
batch_size = tf.shape(prev_state)[0] |
|
if t > 0: |
|
if self.transition_type == ROUND_TRANSITION: |
|
loc = tf.round(prev_state) |
|
tf.logging.info("p(z_%d | z_%d) ~ N(round(z_%d), %0.1f)" % (t, t-1, t-1, self.variance)) |
|
      elif self.transition_type == STANDARD_TRANSITION:

        loc = prev_state

        tf.logging.info("p(z_%d | z_%d) ~ N(z_%d, %0.1f)" % (t, t-1, t-1, self.variance))

      else:

        raise ValueError("Unknown transition_type: %s" % self.transition_type)
|
else: |
|
loc = tf.zeros([batch_size, self.state_size], dtype=self.dtype) |
|
tf.logging.info("p(z_0) ~ N(0,%0.1f)" % self.variance) |
|
|
|
p_zt = self.transition_dist( |
|
loc=loc, |
|
scale=tf.sqrt(tf.ones_like(loc) * self.variance)) |
|
return p_zt |
|
|
|
def generative(self, unused_obs, z_ni): |
|
"""Computes the model's generative distribution p(x_i| z_{ni}).""" |
|
if self.transition_type == ROUND_TRANSITION: |
|
loc = tf.round(z_ni) |
|
    elif self.transition_type == STANDARD_TRANSITION:

      loc = z_ni

    else:

      raise ValueError("Unknown transition_type: %s" % self.transition_type)
|
generative_sigma_sq = tf.ones_like(loc) * self.observation_variance |
|
return self.transition_dist( |
|
loc=loc, scale=tf.sqrt(generative_sigma_sq)) |
|
|
|
|
|
class BimodalPriorP(object):

  """Gaussian chain prior whose initial state is a two-mode mixture."""
|
|
|
def __init__(self, |
|
state_size, |
|
num_timesteps, |
|
mixing_coeff=0.5, |
|
prior_mode_mean=1, |
|
sigma_min=1e-5, |
|
variance=1.0, |
|
dtype=tf.float32, |
|
random_seed=None, |
|
trainable=True, |
|
init_bs_to_zero=False, |
|
graph_collection_name="P_VARS"): |
|
self.state_size = state_size |
|
self.num_timesteps = num_timesteps |
|
self.sigma_min = sigma_min |
|
self.dtype = dtype |
|
self.variance = variance |
|
self.mixing_coeff = mixing_coeff |
|
self.prior_mode_mean = prior_mode_mean |
|
|
|
    if init_bs_to_zero:

      initializers = [tf.zeros_initializer() for _ in xrange(num_timesteps)]

    else:

      initializers = [tf.random_uniform_initializer(seed=random_seed)

                      for _ in xrange(num_timesteps)]
|
|
|
self.bs = [ |
|
tf.get_variable( |
|
shape=[state_size], |
|
dtype=self.dtype, |
|
name="b_%d" % (t + 1), |
|
initializer=initializers[t], |
|
collections=[tf.GraphKeys.GLOBAL_VARIABLES, graph_collection_name], |
|
trainable=trainable) for t in xrange(num_timesteps) |
|
] |
|
self.Bs = tf.cumsum(self.bs, reverse=True, axis=0) |
|
|
|
  def posterior(self, observation, prev_state, t):

    """Computes the true posterior p(z_t|z_{t-1}, z_n)."""
|
mu = observation - self.Bs[t] |
|
if t > 0: |
|
mu += (prev_state + self.bs[t - 1]) * float(self.num_timesteps - t) |
|
mu /= float(self.num_timesteps - t + 1) |
|
sigma = tf.ones_like(mu) * self.variance * ( |
|
float(self.num_timesteps - t) / float(self.num_timesteps - t + 1)) |
|
return tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma)) |
|
|
|
def lookahead(self, state, t): |
|
"""Computes the true lookahead distribution p(z_n|z_t).""" |
|
mu = state + self.Bs[t] |
|
sigma = tf.ones_like(state) * self.variance * float(self.num_timesteps - t) |
|
return tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma)) |
|
|
|
def likelihood(self, observation): |
|
batch_size = tf.shape(observation)[0] |
|
sum_of_bs = tf.tile(tf.reduce_sum(self.bs, axis=0)[tf.newaxis, :], [batch_size, 1]) |
|
sigma = tf.ones_like(sum_of_bs) * self.variance * (self.num_timesteps + 1) |
|
mu_pos = (tf.ones([batch_size, self.state_size], dtype=self.dtype) * self.prior_mode_mean) + sum_of_bs |
|
mu_neg = (tf.ones([batch_size, self.state_size], dtype=self.dtype) * -self.prior_mode_mean) + sum_of_bs |
|
zn_pos = tf.contrib.distributions.Normal( |
|
loc=mu_pos, |
|
scale=tf.sqrt(sigma)) |
|
zn_neg = tf.contrib.distributions.Normal( |
|
loc=mu_neg, |
|
scale=tf.sqrt(sigma)) |
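
    # Note: tiling the probs to [batch_size, 1, 2] gives the Categorical a

    # batch shape of [batch_size, 1], which matches the components' batch

    # shape of [batch_size, state_size] only when state_size == 1, so this

    # mixture appears to assume a one-dimensional state.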
|
mode_probs = tf.convert_to_tensor([self.mixing_coeff, 1-self.mixing_coeff], dtype=tf.float64) |
|
mode_probs = tf.tile(mode_probs[tf.newaxis, tf.newaxis, :], [batch_size, 1, 1]) |
|
mode_selection_dist = tf.contrib.distributions.Categorical(probs=mode_probs) |
|
zn_dist = tf.contrib.distributions.Mixture( |
|
cat=mode_selection_dist, |
|
components=[zn_pos, zn_neg], |
|
validate_args=True) |
|
|
|
return tf.reduce_mean(tf.reduce_sum(zn_dist.log_prob(observation), axis=1)) |
|
|
|
def p_zt(self, prev_state, t): |
|
"""Computes the model p(z_t| z_{t-1}).""" |
|
batch_size = tf.shape(prev_state)[0] |
|
if t > 0: |
|
z_mu_p = prev_state + self.bs[t - 1] |
|
p_zt = tf.contrib.distributions.Normal( |
|
loc=z_mu_p, scale=tf.sqrt(tf.ones_like(z_mu_p) * self.variance)) |
|
return p_zt |
|
else: |
|
mu_pos = tf.ones([batch_size, self.state_size], dtype=self.dtype) * self.prior_mode_mean |
|
mu_neg = tf.ones([batch_size, self.state_size], dtype=self.dtype) * -self.prior_mode_mean |
|
z0_pos = tf.contrib.distributions.Normal( |
|
loc=mu_pos, |
|
scale=tf.sqrt(tf.ones_like(mu_pos) * self.variance)) |
|
z0_neg = tf.contrib.distributions.Normal( |
|
loc=mu_neg, |
|
scale=tf.sqrt(tf.ones_like(mu_neg) * self.variance)) |
|
mode_probs = tf.convert_to_tensor([self.mixing_coeff, 1-self.mixing_coeff], dtype=tf.float64) |
|
mode_probs = tf.tile(mode_probs[tf.newaxis, tf.newaxis, :], [batch_size, 1, 1]) |
|
mode_selection_dist = tf.contrib.distributions.Categorical(probs=mode_probs) |
|
z0_dist = tf.contrib.distributions.Mixture( |
|
cat=mode_selection_dist, |
|
components=[z0_pos, z0_neg], |
|
validate_args=False) |
|
return z0_dist |
|
|
|
def generative(self, unused_observation, z_nm1): |
|
"""Computes the model's generative distribution p(z_n| z_{n-1}).""" |
|
generative_p_mu = z_nm1 + self.bs[-1] |
|
return tf.contrib.distributions.Normal( |
|
loc=generative_p_mu, scale=tf.sqrt(tf.ones_like(generative_p_mu) * self.variance)) |
|
|
|
class Model(object):

  """Combines p, q, and r and computes the per-step log-weight terms."""
|
|
|
def __init__(self, |
|
p, |
|
q, |
|
r, |
|
state_size, |
|
num_timesteps, |
|
dtype=tf.float32): |
|
self.p = p |
|
self.q = q |
|
self.r = r |
|
self.state_size = state_size |
|
self.num_timesteps = num_timesteps |
|
self.dtype = dtype |
|
|
|
def zero_state(self, batch_size): |
|
return tf.zeros([batch_size, self.state_size], dtype=self.dtype) |
|
|
|
  def __call__(self, prev_state, observation, t):

    """Samples z_t from q and computes the per-step log-weight terms."""
|
|
|
q_zt = self.q.q_zt(observation, prev_state, t) |
|
|
|
p_zt = self.p.p_zt(prev_state, t) |
|
|
|
zt = q_zt.sample() |
|
r_xn = self.r.r_xn(zt, t) |
|
|
|
log_q_zt = tf.reduce_sum(q_zt.log_prob(zt), axis=1) |
|
log_p_zt = tf.reduce_sum(p_zt.log_prob(zt), axis=1) |
|
log_r_xn = tf.reduce_sum(r_xn.log_prob(observation), axis=1) |
|
|
|
if t == self.num_timesteps - 1: |
|
generative_dist = self.p.generative(observation, zt) |
|
log_p_x_given_z = tf.reduce_sum(generative_dist.log_prob(observation), axis=1) |
|
else: |
|
log_p_x_given_z = tf.zeros_like(log_q_zt) |
|
return (zt, log_q_zt, log_p_zt, log_p_x_given_z, log_r_xn) |
|
|
|
@staticmethod |
|
def create(state_size, |
|
num_timesteps, |
|
sigma_min=1e-5, |
|
r_sigma_init=1, |
|
variance=1.0, |
|
mixing_coeff=0.5, |
|
prior_mode_mean=1.0, |
|
dtype=tf.float32, |
|
random_seed=None, |
|
train_p=True, |
|
p_type="unimodal", |
|
q_type="normal", |
|
observation_variance=1.0, |
|
transition_type=STANDARD_TRANSITION, |
|
use_bs=True): |
|
if p_type == "unimodal": |
|
p = P(state_size, |
|
num_timesteps, |
|
sigma_min=sigma_min, |
|
variance=variance, |
|
dtype=dtype, |
|
random_seed=random_seed, |
|
trainable=train_p, |
|
init_bs_to_zero=not use_bs) |
|
elif p_type == "bimodal": |
|
p = BimodalPriorP( |
|
state_size, |
|
num_timesteps, |
|
mixing_coeff=mixing_coeff, |
|
prior_mode_mean=prior_mode_mean, |
|
sigma_min=sigma_min, |
|
variance=variance, |
|
dtype=dtype, |
|
random_seed=random_seed, |
|
trainable=train_p, |
|
init_bs_to_zero=not use_bs) |
|
elif "nonlinear" in p_type: |
|
if "cauchy" in p_type: |
|
trans_dist = tf.contrib.distributions.Cauchy |
|
else: |
|
trans_dist = tf.contrib.distributions.Normal |
|
p = ShortChainNonlinearP( |
|
state_size, |
|
num_timesteps, |
|
sigma_min=sigma_min, |
|
variance=variance, |
|
observation_variance=observation_variance, |
|
transition_type=transition_type, |
|
transition_dist=trans_dist, |
|
dtype=dtype, |
|
          random_seed=random_seed)

    else:

      raise ValueError("Unknown p_type: %s" % p_type)
|
|
|
if q_type == "normal": |
|
q_class = Q |
|
elif q_type == "simple_mean": |
|
q_class = SimpleMeanQ |
|
elif q_type == "prev_state": |
|
q_class = PreviousStateQ |
|
elif q_type == "observation": |
|
q_class = ObservationQ |
|
|
|
q = q_class(state_size, |
|
num_timesteps, |
|
sigma_min=sigma_min, |
|
dtype=dtype, |
|
random_seed=random_seed, |
|
init_mu0_to_zero=not use_bs) |
|
r = R(state_size, |
|
num_timesteps, |
|
sigma_min=sigma_min, |
|
sigma_init=r_sigma_init, |
|
dtype=dtype, |
|
random_seed=random_seed) |
|
model = Model(p, q, r, state_size, num_timesteps, dtype=dtype) |
|
return model |
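
# Illustrative sketch (not part of the original module): build a model inside

# a TF1 graph and run one timestep; the sizes and the p_type/q_type values

# below are arbitrary examples.

#

#   model = Model.create(state_size=1, num_timesteps=5,

#                        p_type="unimodal", q_type="normal")

#   prev_state = model.zero_state(batch_size=4)

#   observation = tf.zeros([4, 1])

#   zt, log_q, log_p, log_p_x, log_r = model(prev_state, observation, t=0)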
|
|
|
|
|
class BackwardsModel(object):

  """A model that traverses the chain in reverse time order."""
|
|
|
def __init__(self, |
|
state_size, |
|
num_timesteps, |
|
sigma_min=1e-5, |
|
dtype=tf.float32): |
|
self.state_size = state_size |
|
self.num_timesteps = num_timesteps |
|
self.sigma_min = sigma_min |
|
self.dtype = dtype |
|
self.bs = [ |
|
tf.get_variable( |
|
shape=[state_size], |
|
dtype=self.dtype, |
|
name="b_%d" % (t + 1), |
|
            initializer=tf.zeros_initializer()) for t in xrange(num_timesteps)
|
] |
|
self.Bs = tf.cumsum(self.bs, reverse=True, axis=0) |
|
self.q_mus = [ |
|
snt.Linear(output_size=state_size) for _ in xrange(num_timesteps) |
|
] |
|
self.q_sigmas = [ |
|
tf.get_variable( |
|
shape=[state_size], |
|
dtype=self.dtype, |
|
name="q_sigma_%d" % (t + 1), |
|
            initializer=tf.zeros_initializer()) for t in xrange(num_timesteps)
|
] |
|
self.r_mus = [ |
|
tf.get_variable( |
|
shape=[state_size], |
|
dtype=self.dtype, |
|
name="r_mu_%d" % (t + 1), |
|
            initializer=tf.zeros_initializer()) for t in xrange(num_timesteps)
|
] |
|
self.r_sigmas = [ |
|
tf.get_variable( |
|
shape=[state_size], |
|
dtype=self.dtype, |
|
name="r_sigma_%d" % (t + 1), |
|
            initializer=tf.zeros_initializer()) for t in xrange(num_timesteps)
|
] |
|
|
|
def zero_state(self, batch_size): |
|
return tf.zeros([batch_size, self.state_size], dtype=self.dtype) |
|
|
|
  def posterior(self, unused_observation, prev_state, unused_t):

    """Placeholder; returns a degenerate zero-scale Normal."""
|
|
|
return tf.contrib.distributions.Normal( |
|
loc=tf.zeros_like(prev_state), scale=tf.zeros_like(prev_state)) |
|
|
|
  def lookahead(self, state, unused_t):

    """Placeholder; returns a degenerate zero-scale Normal."""
|
|
|
return tf.contrib.distributions.Normal( |
|
loc=tf.zeros_like(state), scale=tf.zeros_like(state)) |
|
|
|
def q_zt(self, observation, next_state, t): |
|
"""Computes the variational posterior q(z_{t}|z_{t+1}, z_n).""" |
|
t_backwards = self.num_timesteps - t - 1 |
|
batch_size = tf.shape(next_state)[0] |
|
q_mu = self.q_mus[t_backwards](tf.concat([observation, next_state], axis=1)) |
|
q_sigma = tf.maximum( |
|
tf.nn.softplus(self.q_sigmas[t_backwards]), self.sigma_min) |
|
q_sigma = tf.tile(q_sigma[tf.newaxis, :], [batch_size, 1]) |
|
q_zt = tf.contrib.distributions.Normal(loc=q_mu, scale=tf.sqrt(q_sigma)) |
|
return q_zt |
|
|
|
def p_zt(self, prev_state, t): |
|
"""Computes the model p(z_{t+1}| z_{t}).""" |
|
t_backwards = self.num_timesteps - t - 1 |
|
z_mu_p = prev_state + self.bs[t_backwards] |
|
p_zt = tf.contrib.distributions.Normal( |
|
loc=z_mu_p, scale=tf.ones_like(z_mu_p)) |
|
return p_zt |
|
|
|
def generative(self, unused_observation, z_nm1): |
|
"""Computes the model's generative distribution p(z_n| z_{n-1}).""" |
|
generative_p_mu = z_nm1 + self.bs[-1] |
|
return tf.contrib.distributions.Normal( |
|
loc=generative_p_mu, scale=tf.ones_like(generative_p_mu)) |
|
|
|
def r(self, z_t, t): |
|
t_backwards = self.num_timesteps - t - 1 |
|
batch_size = tf.shape(z_t)[0] |
|
r_mu = tf.tile(self.r_mus[t_backwards][tf.newaxis, :], [batch_size, 1]) |
|
r_sigma = tf.maximum( |
|
tf.nn.softplus(self.r_sigmas[t_backwards]), self.sigma_min) |
|
r_sigma = tf.tile(r_sigma[tf.newaxis, :], [batch_size, 1]) |
|
return tf.contrib.distributions.Normal(loc=r_mu, scale=tf.sqrt(r_sigma)) |
|
|
|
def likelihood(self, observation): |
|
batch_size = tf.shape(observation)[0] |
|
mu = tf.tile(tf.reduce_sum(self.bs, axis=0)[tf.newaxis, :], [batch_size, 1]) |
|
sigma = tf.ones_like(mu) * (self.num_timesteps + 1) |
|
dist = tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma)) |
|
|
|
return tf.reduce_mean(tf.reduce_sum(dist.log_prob(observation), axis=1)) |
|
|
|
  def __call__(self, next_state, observation, t):

    """Samples z_t from q (backwards in time) and computes log-weight terms."""
|
|
|
|
|
q_zt = self.q_zt(observation, next_state, t) |
|
|
|
zt = q_zt.sample() |
|
|
|
p_zt = self.p_zt(zt, t) |
|
|
|
if t == 0: |
|
log_p_zt = p_zt.log_prob(observation) |
|
else: |
|
log_p_zt = p_zt.log_prob(next_state) |
|
|
|
|
|
r_zt = self.r(zt, t) |
|
log_r_zt = r_zt.log_prob(zt) |
|
|
|
log_q_zt = q_zt.log_prob(zt) |
|
|
|
|
|
if t == self.num_timesteps - 1: |
|
p_z0_dist = tf.contrib.distributions.Normal( |
|
loc=tf.zeros_like(zt), scale=tf.ones_like(zt)) |
|
z0_log_prob = p_z0_dist.log_prob(zt) |
|
else: |
|
z0_log_prob = tf.zeros_like(log_q_zt) |
|
return (zt, log_q_zt, log_p_zt, z0_log_prob, log_r_zt) |
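
# Illustrative sketch (not part of the original module): the backwards model

# starts from the observed endpoint and walks toward z_0; at t == 0 the

# transition is scored against the observation.

#

#   bwd = BackwardsModel(state_size=1, num_timesteps=3)

#   outs = bwd(bwd.zero_state(batch_size=4), tf.zeros([4, 1]), t=0)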
|
|
|
|
|
class LongChainP(object):

  """Prior for chains that emit an observation every steps_per_obs steps."""
|
|
|
def __init__(self, |
|
state_size, |
|
num_obs, |
|
steps_per_obs, |
|
sigma_min=1e-5, |
|
variance=1.0, |
|
observation_variance=1.0, |
|
observation_type=STANDARD_OBSERVATION, |
|
transition_type=STANDARD_TRANSITION, |
|
dtype=tf.float32, |
|
random_seed=None): |
|
self.state_size = state_size |
|
self.steps_per_obs = steps_per_obs |
|
self.num_obs = num_obs |
|
self.num_timesteps = steps_per_obs*num_obs + 1 |
|
self.sigma_min = sigma_min |
|
self.dtype = dtype |
|
self.variance = variance |
|
self.observation_variance = observation_variance |
|
self.observation_type = observation_type |
|
self.transition_type = transition_type |
|
|
|
  def likelihood(self, observations):

    """The model's true likelihood of the observations.

    Args:

      observations: A [batch_size, m, state_size] Tensor representing each of

        the m observations.

    Raises:

      ValueError: Always, since the likelihood is not defined for long-chain

        models.

    """

    raise ValueError("Likelihood is not defined for long-chain models")
|
|
|
def p_zt(self, prev_state, t): |
|
"""Computes the model p(z_t| z_{t-1}).""" |
|
batch_size = tf.shape(prev_state)[0] |
|
if t > 0: |
|
if self.transition_type == ROUND_TRANSITION: |
|
loc = tf.round(prev_state) |
|
tf.logging.info("p(z_%d | z_%d) ~ N(round(z_%d), %0.1f)" % (t, t-1, t-1, self.variance)) |
|
      elif self.transition_type == STANDARD_TRANSITION:

        loc = prev_state

        tf.logging.info("p(z_%d | z_%d) ~ N(z_%d, %0.1f)" % (t, t-1, t-1, self.variance))

      else:

        raise ValueError("Unknown transition_type: %s" % self.transition_type)
|
else: |
|
loc = tf.zeros([batch_size, self.state_size], dtype=self.dtype) |
|
tf.logging.info("p(z_0) ~ N(0,%0.1f)" % self.variance) |
|
|
|
p_zt = tf.contrib.distributions.Normal( |
|
loc=loc, |
|
scale=tf.sqrt(tf.ones_like(loc) * self.variance)) |
|
return p_zt |
|
|
|
def generative(self, z_ni, t): |
|
"""Computes the model's generative distribution p(x_i| z_{ni}).""" |
|
    if self.observation_type == SQUARED_OBSERVATION:

      generative_mu = tf.square(z_ni)

      tf.logging.info("p(x_%d | z_%d) ~ N(z_%d^2, %0.1f)" % (t, t, t, self.observation_variance))

    elif self.observation_type == ABS_OBSERVATION:

      generative_mu = tf.abs(z_ni)

      tf.logging.info("p(x_%d | z_%d) ~ N(|z_%d|, %0.1f)" % (t, t, t, self.observation_variance))

    elif self.observation_type == STANDARD_OBSERVATION:

      generative_mu = z_ni

      tf.logging.info("p(x_%d | z_%d) ~ N(z_%d, %0.1f)" % (t, t, t, self.observation_variance))

    else:

      raise ValueError("Unknown observation_type: %s" % self.observation_type)
|
generative_sigma_sq = tf.ones_like(generative_mu) * self.observation_variance |
|
return tf.contrib.distributions.Normal( |
|
loc=generative_mu, scale=tf.sqrt(generative_sigma_sq)) |
|
|
|
|
|
class LongChainQ(object):

  """Variational posterior conditioned on all not-yet-passed observations."""
|
|
|
def __init__(self, |
|
state_size, |
|
num_obs, |
|
steps_per_obs, |
|
sigma_min=1e-5, |
|
dtype=tf.float32, |
|
random_seed=None): |
|
self.state_size = state_size |
|
self.sigma_min = sigma_min |
|
self.dtype = dtype |
|
self.steps_per_obs = steps_per_obs |
|
self.num_obs = num_obs |
|
    self.num_timesteps = num_obs * steps_per_obs + 1
|
|
|
initializers = { |
|
"w": tf.random_uniform_initializer(seed=random_seed), |
|
"b": tf.zeros_initializer |
|
} |
|
self.mus = [ |
|
snt.Linear(output_size=state_size, initializers=initializers) |
|
for t in xrange(self.num_timesteps) |
|
] |
|
self.sigmas = [ |
|
tf.get_variable( |
|
shape=[state_size], |
|
dtype=self.dtype, |
|
name="q_sigma_%d" % (t + 1), |
|
initializer=tf.random_uniform_initializer(seed=random_seed)) |
|
for t in xrange(self.num_timesteps) |
|
] |
|
|
|
def first_relevant_obs_index(self, t): |
|
return int(max((t-1)/self.steps_per_obs, 0)) |
|
|
|
def q_zt(self, observations, prev_state, t): |
|
"""Computes a distribution over z_t. |
|
|
|
Args: |
|
observations: a [batch_size, num_observations, state_size] Tensor. |
|
prev_state: a [batch_size, state_size] Tensor. |
|
      t: The current timestep, a Python int.
|
""" |
|
|
|
first_relevant_obs_index = int(math.floor(max(t-1, 0) / self.steps_per_obs)) |
|
num_relevant_observations = self.num_obs - first_relevant_obs_index |
|
observations = observations[:,first_relevant_obs_index:,:] |
|
batch_size = tf.shape(prev_state)[0] |
|
|
|
|
|
|
|
|
|
q_input = tf.concat([observations, prev_state[:,tf.newaxis, :]], axis=1) |
|
q_input = tf.reshape(q_input, |
|
[batch_size, (num_relevant_observations + 1) * self.state_size]) |
|
q_mu = self.mus[t](q_input) |
|
q_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min) |
|
q_sigma = tf.tile(q_sigma[tf.newaxis, :], [batch_size, 1]) |
|
q_zt = tf.contrib.distributions.Normal(loc=q_mu, scale=tf.sqrt(q_sigma)) |
|
tf.logging.info( |
|
"q(z_{t} | z_{tm1}, x_{obsf}:{obst}) ~ N(Linear([z_{tm1},x_{obsf}:{obst}]), sigma_{t})".format( |
|
**{"t": t, |
|
"tm1": t-1, |
|
"obsf": (first_relevant_obs_index+1)*self.steps_per_obs, |
|
"obst":self.steps_per_obs*self.num_obs})) |
|
return q_zt |
|
|
|
def summarize_weights(self): |
|
pass |
|
|
|
class LongChainR(object):

  """Lookahead distribution over all future observations given z_t."""
|
|
|
def __init__(self, |
|
state_size, |
|
num_obs, |
|
steps_per_obs, |
|
sigma_min=1e-5, |
|
dtype=tf.float32, |
|
random_seed=None): |
|
self.state_size = state_size |
|
self.dtype = dtype |
|
self.sigma_min = sigma_min |
|
self.steps_per_obs = steps_per_obs |
|
self.num_obs = num_obs |
|
self.num_timesteps = num_obs*steps_per_obs + 1 |
|
self.sigmas = [ |
|
tf.get_variable( |
|
shape=[self.num_future_obs(t)], |
|
dtype=self.dtype, |
|
name="r_sigma_%d" % (t + 1), |
|
|
|
initializer=tf.constant_initializer(1.0)) |
|
for t in range(self.num_timesteps) |
|
] |
|
|
|
def first_future_obs_index(self, t): |
|
return int(math.floor(t / self.steps_per_obs)) |
|
|
|
def num_future_obs(self, t): |
|
return int(self.num_obs - self.first_future_obs_index(t)) |
|
|
|
def r_xn(self, z_t, t): |
|
"""Computes a distribution over the future observations given current latent |
|
state. |
|
|
|
    The indexing in these messages is 1-indexed and inclusive, consistent with

    the LaTeX documents.
|
|
|
Args: |
|
z_t: [batch_size, state_size] Tensor |
|
t: Current timestep |
|
""" |
|
tf.logging.info( |
|
"r(x_{start}:{end} | z_{t}) ~ N(z_{t}, sigma_{t})".format( |
|
**{"t": t, |
|
"start": (self.first_future_obs_index(t)+1)*self.steps_per_obs, |
|
"end": self.num_timesteps-1})) |
|
batch_size = tf.shape(z_t)[0] |
|
|
|
|
|
r_mu = tf.tile(z_t[:,tf.newaxis,:], [1, self.num_future_obs(t), 1]) |
|
|
|
r_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min) |
|
|
|
|
|
r_sigma = tf.tile(r_sigma[tf.newaxis,:, tf.newaxis], [batch_size, 1, self.state_size]) |
|
return tf.contrib.distributions.Normal( |
|
loc=r_mu, scale=tf.sqrt(r_sigma)) |
|
|
|
def summarize_weights(self): |
|
pass |
|
|
|
|
|
class LongChainModel(object):

  """Combines LongChainP, LongChainQ, and LongChainR into one model."""
|
|
|
def __init__(self, |
|
p, |
|
q, |
|
r, |
|
state_size, |
|
num_obs, |
|
steps_per_obs, |
|
dtype=tf.float32, |
|
disable_r=False): |
|
self.p = p |
|
self.q = q |
|
self.r = r |
|
self.disable_r = disable_r |
|
self.state_size = state_size |
|
self.num_obs = num_obs |
|
self.steps_per_obs = steps_per_obs |
|
self.num_timesteps = steps_per_obs*num_obs + 1 |
|
self.dtype = dtype |
|
|
|
def zero_state(self, batch_size): |
|
return tf.zeros([batch_size, self.state_size], dtype=self.dtype) |
|
|
|
def next_obs_ind(self, t): |
|
return int(math.floor(max(t-1,0)/self.steps_per_obs)) |
|
|
|
def __call__(self, prev_state, observations, t): |
|
"""Computes the importance weight for the model system. |
|
|
|
    Args:

      prev_state: [batch_size, state_size] Tensor.

      observations: [batch_size, num_observations, state_size] Tensor.

      t: The current timestep, a Python int.

    """
|
|
|
q_zt = self.q.q_zt(observations, prev_state, t) |
|
|
|
p_zt = self.p.p_zt(prev_state, t) |
|
|
|
zt = q_zt.sample() |
|
log_q_zt = tf.reduce_sum(q_zt.log_prob(zt), axis=1) |
|
log_p_zt = tf.reduce_sum(p_zt.log_prob(zt), axis=1) |
|
if not self.disable_r and t < self.num_timesteps-1: |
|
|
|
r_xn = self.r.r_xn(zt, t) |
|
log_r_xn = r_xn.log_prob(observations[:, self.next_obs_ind(t+1):, :]) |
|
|
|
log_r_xn = tf.reduce_sum(log_r_xn, axis=[1,2]) |
|
else: |
|
log_r_xn = tf.zeros_like(log_p_zt) |
|
if t != 0 and t % self.steps_per_obs == 0: |
|
generative_dist = self.p.generative(zt, t) |
|
log_p_x_given_z = generative_dist.log_prob(observations[:,self.next_obs_ind(t),:]) |
|
log_p_x_given_z = tf.reduce_sum(log_p_x_given_z, axis=1) |
|
else: |
|
log_p_x_given_z = tf.zeros_like(log_q_zt) |
|
return (zt, log_q_zt, log_p_zt, log_p_x_given_z, log_r_xn) |
|
|
|
@staticmethod |
|
def create(state_size, |
|
num_obs, |
|
steps_per_obs, |
|
sigma_min=1e-5, |
|
variance=1.0, |
|
observation_variance=1.0, |
|
observation_type=STANDARD_OBSERVATION, |
|
transition_type=STANDARD_TRANSITION, |
|
dtype=tf.float32, |
|
random_seed=None, |
|
disable_r=False): |
|
p = LongChainP( |
|
state_size, |
|
num_obs, |
|
steps_per_obs, |
|
sigma_min=sigma_min, |
|
variance=variance, |
|
observation_variance=observation_variance, |
|
observation_type=observation_type, |
|
transition_type=transition_type, |
|
dtype=dtype, |
|
random_seed=random_seed) |
|
q = LongChainQ( |
|
state_size, |
|
num_obs, |
|
steps_per_obs, |
|
sigma_min=sigma_min, |
|
dtype=dtype, |
|
random_seed=random_seed) |
|
r = LongChainR( |
|
state_size, |
|
num_obs, |
|
steps_per_obs, |
|
sigma_min=sigma_min, |
|
dtype=dtype, |
|
random_seed=random_seed) |
|
model = LongChainModel( |
|
p, q, r, state_size, num_obs, steps_per_obs, |
|
dtype=dtype, |
|
disable_r=disable_r) |
|
return model |
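
# Illustrative sketch (not part of the original module): a long chain with 2

# observations and 3 latent steps per observation; all sizes are arbitrary

# example values.

#

#   model = LongChainModel.create(state_size=1, num_obs=2, steps_per_obs=3)

#   prev_state = model.zero_state(batch_size=4)

#   observations = tf.zeros([4, 2, 1])

#   zt, log_q, log_p, log_p_x, log_r = model(prev_state, observations, t=0)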
|
|
|
|
|
class RTilde(object):

  """Amortized r-tilde network emitting a Gaussian mean and variance per step."""
|
|
|
def __init__(self, |
|
state_size, |
|
num_timesteps, |
|
sigma_min=1e-5, |
|
dtype=tf.float32, |
|
random_seed=None, |
|
graph_collection_name="R_TILDE_VARS"): |
|
self.dtype = dtype |
|
self.sigma_min = sigma_min |
|
initializers = {"w": tf.truncated_normal_initializer(seed=random_seed), |
|
"b": tf.zeros_initializer} |
|
    self.graph_collection_name = graph_collection_name
|
|
|
def custom_getter(getter, *args, **kwargs): |
|
out = getter(*args, **kwargs) |
|
ref = tf.get_collection_ref(self.graph_collection_name) |
|
if out not in ref: |
|
ref.append(out) |
|
return out |
|
|
|
self.fns = [ |
|
snt.Linear(output_size=2*state_size, |
|
initializers=initializers, |
|
name="r_tilde_%d" % t, |
|
custom_getter=custom_getter) |
|
for t in xrange(num_timesteps) |
|
] |
|
|
|
  def r_zt(self, z_t, observation, t):

    """Computes the mean and variance of the r-tilde Gaussian at step t."""
|
|
|
out = self.fns[t](tf.concat([z_t, observation], axis=1)) |
|
mu, raw_sigma_sq = tf.split(out, 2, axis=1) |
|
sigma_sq = tf.maximum(tf.nn.softplus(raw_sigma_sq), self.sigma_min) |
|
return mu, sigma_sq |
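
# Illustrative sketch (not part of the original module): r-tilde returns raw

# Gaussian parameters rather than a distribution object.

#

#   r_tilde = RTilde(state_size=1, num_timesteps=3)

#   mu, sigma_sq = r_tilde.r_zt(z_t, observation, t=0)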
|
|
|
class TDModel(object):

  """Model variant that also returns r-tilde parameters and p(z_{t+1} | z_t)."""
|
|
|
def __init__(self, |
|
p, |
|
q, |
|
r_tilde, |
|
state_size, |
|
num_timesteps, |
|
dtype=tf.float32, |
|
disable_r=False): |
|
self.p = p |
|
self.q = q |
|
self.r_tilde = r_tilde |
|
self.disable_r = disable_r |
|
self.state_size = state_size |
|
self.num_timesteps = num_timesteps |
|
self.dtype = dtype |
|
|
|
def zero_state(self, batch_size): |
|
return tf.zeros([batch_size, self.state_size], dtype=self.dtype) |
|
|
|
def __call__(self, prev_state, observation, t): |
|
"""Computes the importance weight for the model system. |
|
|
|
    Args:

      prev_state: [batch_size, state_size] Tensor.

      observation: [batch_size, state_size] Tensor.

      t: The current timestep, a Python int.

    """
|
|
|
q_zt = self.q.q_zt(observation, prev_state, t) |
|
|
|
p_zt = self.p.p_zt(prev_state, t) |
|
|
|
zt = q_zt.sample() |
|
|
|
if t < self.num_timesteps - 1: |
|
p_ztplus1 = self.p.p_zt(zt, t+1) |
|
else: |
|
p_ztplus1 = None |
|
log_q_zt = tf.reduce_sum(q_zt.log_prob(zt), axis=1) |
|
log_p_zt = tf.reduce_sum(p_zt.log_prob(zt), axis=1) |
|
|
|
if not self.disable_r and t < self.num_timesteps-1: |
|
|
|
r_tilde_mu, r_tilde_sigma_sq = self.r_tilde.r_zt(zt, observation, t+1) |
|
else: |
|
r_tilde_mu = None |
|
r_tilde_sigma_sq = None |
|
if t == self.num_timesteps - 1: |
|
generative_dist = self.p.generative(observation, zt) |
|
log_p_x_given_z = tf.reduce_sum(generative_dist.log_prob(observation), axis=1) |
|
else: |
|
log_p_x_given_z = tf.zeros_like(log_q_zt) |
|
return (zt, log_q_zt, log_p_zt, log_p_x_given_z, |
|
r_tilde_mu, r_tilde_sigma_sq, p_ztplus1) |
|
|
|
@staticmethod |
|
def create(state_size, |
|
num_timesteps, |
|
sigma_min=1e-5, |
|
variance=1.0, |
|
dtype=tf.float32, |
|
random_seed=None, |
|
train_p=True, |
|
p_type="unimodal", |
|
q_type="normal", |
|
mixing_coeff=0.5, |
|
prior_mode_mean=1.0, |
|
observation_variance=1.0, |
|
transition_type=STANDARD_TRANSITION, |
|
use_bs=True): |
|
if p_type == "unimodal": |
|
p = P(state_size, |
|
num_timesteps, |
|
sigma_min=sigma_min, |
|
variance=variance, |
|
dtype=dtype, |
|
random_seed=random_seed, |
|
trainable=train_p, |
|
init_bs_to_zero=not use_bs) |
|
elif p_type == "bimodal": |
|
p = BimodalPriorP( |
|
state_size, |
|
num_timesteps, |
|
mixing_coeff=mixing_coeff, |
|
prior_mode_mean=prior_mode_mean, |
|
sigma_min=sigma_min, |
|
variance=variance, |
|
dtype=dtype, |
|
random_seed=random_seed, |
|
trainable=train_p, |
|
init_bs_to_zero=not use_bs) |
|
elif "nonlinear" in p_type: |
|
if "cauchy" in p_type: |
|
trans_dist = tf.contrib.distributions.Cauchy |
|
else: |
|
trans_dist = tf.contrib.distributions.Normal |
|
|
|
p = ShortChainNonlinearP( |
|
state_size, |
|
num_timesteps, |
|
sigma_min=sigma_min, |
|
variance=variance, |
|
observation_variance=observation_variance, |
|
transition_type=transition_type, |
|
transition_dist=trans_dist, |
|
dtype=dtype, |
|
          random_seed=random_seed)

    else:

      raise ValueError("Unknown p_type: %s" % p_type)
|
|
|
if q_type == "normal": |
|
q_class = Q |
|
elif q_type == "simple_mean": |
|
q_class = SimpleMeanQ |
|
elif q_type == "prev_state": |
|
q_class = PreviousStateQ |
|
elif q_type == "observation": |
|
q_class = ObservationQ |
|
|
|
q = q_class(state_size, |
|
num_timesteps, |
|
sigma_min=sigma_min, |
|
dtype=dtype, |
|
random_seed=random_seed, |
|
init_mu0_to_zero=not use_bs) |
|
r_tilde = RTilde( |
|
state_size, |
|
num_timesteps, |
|
sigma_min=sigma_min, |
|
dtype=dtype, |
|
random_seed=random_seed) |
|
model = TDModel(p, q, r_tilde, state_size, num_timesteps, dtype=dtype) |
|
return model |
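
# Illustrative sketch (not part of the original module):

#

#   model = TDModel.create(state_size=1, num_timesteps=5)

#   prev_state = model.zero_state(batch_size=4)

#   observation = tf.zeros([4, 1])

#   (zt, log_q, log_p, log_p_x,

#    r_tilde_mu, r_tilde_sigma_sq, p_ztplus1) = model(prev_state, observation, t=0)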
|
|