from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
from collections import namedtuple |
|
|
|
import tensorflow as tf |
|
from six.moves import xrange  # pylint: disable=redefined-builtin

import summary_utils as summ
|
|
|
Loss = namedtuple("Loss", "name loss vars") |
|
Loss.__new__.__defaults__ = (tf.GraphKeys.TRAINABLE_VARIABLES,) |
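# A Loss pairs a name with a scalar loss Tensor and the variable collection it
# should train; by default a loss updates all trainable variables.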
|
|
|
|
|
def iwae(model, observation, num_timesteps, num_samples=1, |
|
summarize=False): |
|
"""Compute the IWAE evidence lower bound. |
|
|
|
Args: |
|
model: A callable that computes one timestep of the model. |
|
observation: A shape [batch_size*num_samples, state_size] Tensor |
|
containing z_n, the observation for each sequence in the batch. |
|
num_timesteps: The number of timesteps in each sequence, an integer. |
|
num_samples: The number of samples to use to compute the IWAE bound. |
|
Returns: |
|
log_p_hat: The IWAE estimator of the lower bound on the log marginal. |
|
loss: A tensor that you can perform gradient descent on to optimize the |
|
bound. |
|
maintain_ema_op: A no-op included for compatibility with FIVO. |
|
states: The sequence of states sampled. |
|
""" |
|
|
|
num_instances = tf.shape(observation)[0] |
|
batch_size = tf.cast(num_instances / num_samples, tf.int32) |
|
states = [model.zero_state(num_instances)] |
|
log_weights = [] |
|
log_weight_acc = tf.zeros([num_samples, batch_size], dtype=observation.dtype) |
|
|
|
for t in xrange(num_timesteps): |
|
|
|
(zt, log_q_zt, log_p_zt, log_p_x_given_z, _) = model( |
|
states[-1], observation, t) |
|
|
|
states.append(zt) |
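    # Incremental importance weight for this timestep:
    # log p(z_t | z_{t-1}) + log p(x_t | z_t) - log q(z_t), reshaped to
    # [num_samples, batch_size] and accumulated over timesteps.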
|
log_weight = log_p_zt + log_p_x_given_z - log_q_zt |
|
log_weight_acc += tf.reshape(log_weight, [num_samples, batch_size]) |
|
if summarize: |
|
weight_dist = tf.contrib.distributions.Categorical( |
|
logits=tf.transpose(log_weight_acc, perm=[1, 0]), |
|
allow_nan_stats=False) |
|
weight_entropy = weight_dist.entropy() |
|
weight_entropy = tf.reduce_mean(weight_entropy) |
|
tf.summary.scalar("weight_entropy/%d" % t, weight_entropy) |
|
log_weights.append(log_weight_acc) |
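  # The IWAE bound: log-mean-exp of the accumulated weights over the
  # num_samples particles, divided by num_timesteps to give a per-timestep
  # bound.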
|
|
|
log_p_hat = (tf.reduce_logsumexp(log_weight_acc, axis=0) - |
|
tf.log(tf.cast(num_samples, observation.dtype))) / num_timesteps |
|
loss = -tf.reduce_mean(log_p_hat) |
|
losses = [Loss("log_p_hat", loss)] |
|
|
|
|
|
|
|
return log_p_hat, losses, tf.no_op(), states[1:], log_weights |
|
|
|
|
|
def multinomial_resampling(log_weights, states, n, b): |
|
"""Resample states with multinomial resampling. |
|
|
|
Args: |
|
log_weights: A (n x b) Tensor representing a batch of b logits for n-ary |
|
Categorical distribution. |
|
states: A list of (b*n x d) Tensors that will be resample in from the groups |
|
of every n-th row. |
|
|
|
Returns: |
|
resampled_states: A list of (b*n x d) Tensors resampled via stratified sampling. |
|
log_probs: A (n x b) Tensor of the log probabilities of the ancestry decisions. |
|
resampling_parameters: The Tensor of parameters of the resampling distribution. |
|
ancestors: An (n x b) Tensor of integral indices representing the ancestry decisions. |
|
resampling_dist: The distribution object for resampling. |
|
""" |
|
log_weights = tf.convert_to_tensor(log_weights) |
|
states = [tf.convert_to_tensor(state) for state in states] |
|
|
|
resampling_parameters = tf.transpose(log_weights, perm=[1,0]) |
|
resampling_dist = tf.contrib.distributions.Categorical(logits=resampling_parameters) |
|
ancestors = tf.stop_gradient( |
|
resampling_dist.sample(sample_shape=n)) |
|
log_probs = resampling_dist.log_prob(ancestors) |
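  # States are flattened from [num_samples, batch_size], so particle i of
  # batch element j sits at row i * b + j. Convert the sampled ancestors into
  # those flat row indices.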
|
|
|
offset = tf.expand_dims(tf.range(b), 0) |
|
ancestor_inds = tf.reshape(ancestors * b + offset, [-1]) |
|
|
|
resampled_states = [] |
|
for state in states: |
|
resampled_states.append(tf.gather(state, ancestor_inds)) |
|
return resampled_states, log_probs, resampling_parameters, ancestors, resampling_dist |
|
|
|
def stratified_resampling(log_weights, states, n, b): |
|
"""Resample states with straitified resampling. |
|
|
|
Args: |
|
log_weights: A (n x b) Tensor representing a batch of b logits for n-ary |
|
Categorical distribution. |
|
states: A list of (b*n x d) Tensors that will be resample in from the groups |
|
of every n-th row. |
|
|
|
Returns: |
|
resampled_states: A list of (b*n x d) Tensors resampled via stratified sampling. |
|
log_probs: A (n x b) Tensor of the log probabilities of the ancestry decisions. |
|
resampling_parameters: The Tensor of parameters of the resampling distribution. |
|
ancestors: An (n x b) Tensor of integral indices representing the ancestry decisions. |
|
resampling_dist: The distribution object for resampling. |
|
""" |
|
log_weights = tf.convert_to_tensor(log_weights) |
|
states = [tf.convert_to_tensor(state) for state in states] |
|
|
|
log_weights = tf.transpose(log_weights, perm=[1,0]) |
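  # Stratified resampling: split [0, 1) into n equal strata and draw one
  # ancestor per stratum. Each stratum's categorical probabilities are the
  # portions of the particles' CDF segments that fall inside that stratum.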
|
|
|
probs = tf.nn.softmax( |
|
tf.tile(tf.expand_dims(log_weights, axis=1), |
|
[1, n, 1]) |
|
) |
|
|
|
cdfs = tf.concat([tf.zeros((b,n,1), dtype=probs.dtype), tf.cumsum(probs, axis=2)], 2) |
|
|
|
bins = tf.range(n, dtype=probs.dtype) / n |
|
bins = tf.tile(tf.reshape(bins, [1,-1,1]), [b,1,n+1]) |
|
|
|
strat_cdfs = tf.minimum(tf.maximum((cdfs - bins) * n, 0.0), 1.0) |
|
resampling_parameters = strat_cdfs[:,:,1:] - strat_cdfs[:,:,:-1] |
|
|
|
resampling_dist = tf.contrib.distributions.Categorical( |
|
probs = resampling_parameters, |
|
allow_nan_stats=False) |
|
|
|
ancestors = tf.stop_gradient( |
|
resampling_dist.sample()) |
|
log_probs = resampling_dist.log_prob(ancestors) |
|
|
|
ancestors = tf.transpose(ancestors, perm=[1,0]) |
|
log_probs = tf.transpose(log_probs, perm=[1,0]) |
|
|
|
offset = tf.expand_dims(tf.range(b), 0) |
|
ancestor_inds = tf.reshape(ancestors * b + offset, [-1]) |
|
|
|
resampled_states = [] |
|
for state in states: |
|
resampled_states.append(tf.gather(state, ancestor_inds)) |
|
|
|
return resampled_states, log_probs, resampling_parameters, ancestors, resampling_dist |
|
|
|
def systematic_resampling(log_weights, states, n, b): |
|
"""Resample states with systematic resampling. |
|
|
|
Args: |
|
log_weights: A (n x b) Tensor representing a batch of b logits for n-ary |
|
Categorical distribution. |
|
states: A list of (b*n x d) Tensors that will be resample in from the groups |
|
of every n-th row. |
|
|
|
Returns: |
|
resampled_states: A list of (b*n x d) Tensors resampled via stratified sampling. |
|
log_probs: A (n x b) Tensor of the log probabilities of the ancestry decisions. |
|
resampling_parameters: The Tensor of parameters of the resampling distribution. |
|
ancestors: An (n x b) Tensor of integral indices representing the ancestry decisions. |
|
resampling_dist: The distribution object for resampling. |
|
""" |
|
|
|
log_weights = tf.convert_to_tensor(log_weights) |
|
states = [tf.convert_to_tensor(state) for state in states] |
|
|
|
log_weights = tf.transpose(log_weights, perm=[1,0]) |
|
|
|
probs = tf.nn.softmax( |
|
tf.tile(tf.expand_dims(log_weights, axis=1), |
|
[1, n, 1]) |
|
) |
|
|
|
cdfs = tf.concat([tf.zeros((b,n,1), dtype=probs.dtype), tf.cumsum(probs, axis=2)], 2) |
|
|
|
bins = tf.range(n, dtype=probs.dtype) / n |
|
bins = tf.tile(tf.reshape(bins, [1,-1,1]), [b,1,n+1]) |
|
|
|
strat_cdfs = tf.minimum(tf.maximum((cdfs - bins) * n, 0.0), 1.0) |
|
resampling_parameters = strat_cdfs[:,:,1:] - strat_cdfs[:,:,:-1] |
|
|
|
resampling_dist = tf.contrib.distributions.Categorical( |
|
probs=resampling_parameters, |
|
allow_nan_stats=True) |
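  # Systematic resampling shares a single uniform draw per batch element: the
  # ancestor for each stratum is the particle whose stratum-rescaled CDF
  # segment contains U.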
|
|
|
U = tf.random_uniform((b, 1, 1), dtype=probs.dtype) |
|
|
|
ancestors = tf.stop_gradient(tf.reduce_sum(tf.to_float(U > strat_cdfs[:,:,1:]), axis=-1)) |
|
log_probs = resampling_dist.log_prob(ancestors) |
|
|
|
ancestors = tf.transpose(ancestors, perm=[1,0]) |
|
log_probs = tf.transpose(log_probs, perm=[1,0]) |
|
|
|
  offset = tf.expand_dims(tf.range(b, dtype=probs.dtype), 0)
  # The ancestors above are floats (counts from tf.to_float), so compute the
  # flat row indices in float and cast to int32 before gathering.
  ancestor_inds = tf.cast(
      tf.reshape(ancestors * tf.cast(b, probs.dtype) + offset, [-1]), tf.int32)
|
|
|
resampled_states = [] |
|
for state in states: |
|
resampled_states.append(tf.gather(state, ancestor_inds)) |
|
|
|
return resampled_states, log_probs, resampling_parameters, ancestors, resampling_dist |
|
|
|
|
|
def log_blend(inputs, weights): |
|
"""Blends state in the log space. |
|
|
|
Args: |
|
inputs: A set of scalar states, one for each particle in each particle filter. |
|
Should be [num_samples, batch_size]. |
|
weights: A set of weights used to blend the state. Each set of weights |
|
should be of dimension [num_samples] (one weight for each previous particle). |
|
There should be one set of weights for each new particle in each particle filter. |
|
Thus the shape should be [num_samples, batch_size, num_samples] where |
|
the first axis indexes new particle and the last axis indexes old particles. |
|
Returns: |
|
blended: The blended states, a tensor of shape [num_samples, batch_size]. |
|
""" |
|
raw_max = tf.reduce_max(inputs, axis=0, keepdims=True) |
|
my_max = tf.stop_gradient( |
|
tf.where(tf.is_finite(raw_max), raw_max, tf.zeros_like(raw_max)) |
|
) |
|
|
|
  # Shift by the finite max (my_max) so that columns that are entirely -inf
  # blend to -inf instead of NaN.
  blended = tf.log(tf.einsum("ijk,kj->ij", weights, tf.exp(inputs - my_max))) + my_max
|
return blended |
|
|
|
|
|
def relaxed_resampling(log_weights, states, num_samples, batch_size, |
|
log_r_x=None, blend_type="log", temperature=0.5, |
|
straight_through=False): |
|
"""Resample states with relaxed resampling. |
|
|
|
Args: |
|
log_weights: A (n x b) Tensor representing a batch of b logits for n-ary |
|
Categorical distribution. |
|
states: A list of (b*n x d) Tensors that will be resample in from the groups |
|
of every n-th row. |
|
|
|
Returns: |
|
resampled_states: A list of (b*n x d) Tensors resampled via stratified sampling. |
|
log_probs: A (n x b) Tensor of the log probabilities of the ancestry decisions. |
|
resampling_parameters: The Tensor of parameters of the resampling distribution. |
|
ancestors: An (n x b x n) Tensor of relaxed one hot representations of the ancestry decisions. |
|
resampling_dist: The distribution object for resampling. |
|
""" |
|
assert blend_type in ["log", "linear"], "Blend type must be 'log' or 'linear'." |
|
log_weights = tf.convert_to_tensor(log_weights) |
|
states = [tf.convert_to_tensor(state) for state in states] |
|
state_dim = states[0].get_shape().as_list()[-1] |
|
|
|
|
|
resampling_parameters = tf.transpose(log_weights, perm=[1, 0]) |
|
resampling_dist = tf.contrib.distributions.RelaxedOneHotCategorical( |
|
temperature, |
|
logits=resampling_parameters) |
|
|
|
|
|
|
|
|
|
|
|
ancestors = resampling_dist.sample(sample_shape=num_samples) |
|
if straight_through: |
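    # Straight-through estimator: forward-pass with the hard one-hot ancestor,
    # backward-pass through the relaxed (soft) sample.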
|
|
|
hard_ancestor_indices = tf.argmax(ancestors, axis=-1) |
|
hard_ancestors = tf.one_hot(hard_ancestor_indices, num_samples, |
|
dtype=ancestors.dtype) |
|
ancestors = tf.stop_gradient(hard_ancestors - ancestors) + ancestors |
|
log_probs = resampling_dist.log_prob(ancestors) |
|
if log_r_x is not None and blend_type == "log": |
|
log_r_x = tf.reshape(log_r_x, [num_samples, batch_size]) |
|
log_r_x = log_blend(log_r_x, ancestors) |
|
log_r_x = tf.reshape(log_r_x, [num_samples*batch_size]) |
|
elif log_r_x is not None and blend_type == "linear": |
|
|
|
|
|
states.append(log_r_x) |
|
|
|
|
|
ancestor_inds = tf.transpose(ancestors, perm=[1, 2, 0]) |
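  # Resampling is a batched matrix product: for each batch element, the
  # [d, num_samples] matrix of old states is multiplied by the
  # [num_samples, num_samples] matrix of relaxed ancestor weights.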
|
resampled_states = [] |
|
for state in states: |
|
|
|
|
|
|
|
state = tf.transpose(tf.reshape(state, [num_samples, batch_size, -1]), perm=[1, 2, 0]) |
|
|
|
|
|
|
|
next_state = tf.matmul(state, ancestor_inds) |
|
|
|
|
|
    # Condense back to [num_samples * batch_size, d]. Use -1 for the feature
    # dimension so that a rank-1 log_r_x appended for linear blending (feature
    # size 1) is handled as well as the d-dimensional states.
    next_state = tf.reshape(tf.transpose(next_state, perm=[2, 0, 1]), [num_samples*batch_size, -1])
|
resampled_states.append(next_state) |
|
|
|
new_dist = tf.contrib.distributions.Categorical( |
|
logits=resampling_parameters) |
|
|
|
if log_r_x is not None and blend_type == "linear": |
|
|
|
log_r_x = tf.squeeze(resampled_states[-1]) |
|
resampled_states = resampled_states[:-1] |
|
return resampled_states, log_probs, log_r_x, resampling_parameters, ancestors, new_dist |
|
|
|
|
|
def fivo(model, |
|
observation, |
|
num_timesteps, |
|
resampling_schedule, |
|
num_samples=1, |
|
use_resampling_grads=True, |
|
resampling_type="multinomial", |
|
resampling_temperature=0.5, |
|
aux=True, |
|
summarize=False): |
|
"""Compute the FIVO evidence lower bound. |
|
|
|
Args: |
|
model: A callable that computes one timestep of the model. |
|
observation: A shape [batch_size*num_samples, state_size] Tensor |
|
containing z_n, the observation for each sequence in the batch. |
|
num_timesteps: The number of timesteps in each sequence, an integer. |
|
resampling_schedule: A list of booleans of length num_timesteps, contains |
|
True if a resampling should occur on a specific timestep. |
|
num_samples: The number of samples to use to compute the IWAE bound. |
|
use_resampling_grads: Whether or not to include the resampling gradients |
|
in loss. |
|
resampling type: The type of resampling, one of "multinomial", "stratified", |
|
"relaxed-logblend", "relaxed-linearblend", "relaxed-stateblend", or |
|
"systematic". |
|
resampling_temperature: A positive temperature only used for relaxed |
|
resampling. |
|
aux: If true, compute the FIVO-AUX bound. |
|
Returns: |
|
log_p_hat: The IWAE estimator of the lower bound on the log marginal. |
|
loss: A tensor that you can perform gradient descent on to optimize the |
|
bound. |
|
maintain_ema_op: An op to update the baseline ema used for the resampling |
|
gradients. |
|
states: The sequence of states sampled. |
|
""" |
|
|
|
num_instances = tf.cast(tf.shape(observation)[0], tf.int32) |
|
batch_size = tf.cast(num_instances / num_samples, tf.int32) |
|
states = [model.zero_state(num_instances)] |
|
prev_state = states[0] |
|
log_weight_acc = tf.zeros(shape=[num_samples, batch_size], dtype=observation.dtype) |
|
prev_log_r_zt = tf.zeros([num_instances], dtype=observation.dtype) |
|
log_weights = [] |
|
log_weights_all = [] |
|
log_p_hats = [] |
|
resampling_log_probs = [] |
|
for t in xrange(num_timesteps): |
|
|
|
(zt, log_q_zt, log_p_zt, log_p_x_given_z, log_r_zt) = model( |
|
prev_state, observation, t) |
|
|
|
states.append(zt) |
|
log_weight = log_p_zt + log_p_x_given_z - log_q_zt |
|
if aux: |
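      # FIVO-AUX: multiply the weights by r(z_t) / r(z_{t-1}) so the auxiliary
      # terms telescope; at the final timestep only the previous log r is
      # removed.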
|
if t == num_timesteps - 1: |
|
log_weight -= prev_log_r_zt |
|
else: |
|
log_weight += log_r_zt - prev_log_r_zt |
|
prev_log_r_zt = log_r_zt |
|
log_weight_acc += tf.reshape(log_weight, [num_samples, batch_size]) |
|
log_weights_all.append(log_weight_acc) |
|
if resampling_schedule[t]: |
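      # Resample the current state (and, for non-relaxed resamplers, the
      # accumulated log r values) in proportion to the accumulated weights.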
|
|
|
|
|
to_resample = [states[-1]] |
|
if aux and "relaxed" not in resampling_type: |
|
to_resample.append(prev_log_r_zt) |
|
|
|
|
|
if resampling_type == "multinomial": |
|
(resampled, |
|
resampling_log_prob, |
|
_, _, _) = multinomial_resampling(log_weight_acc, |
|
to_resample, |
|
num_samples, |
|
batch_size) |
|
elif resampling_type == "stratified": |
|
(resampled, |
|
resampling_log_prob, |
|
_, _, _) = stratified_resampling(log_weight_acc, |
|
to_resample, |
|
num_samples, |
|
batch_size) |
|
elif resampling_type == "systematic": |
|
(resampled, |
|
resampling_log_prob, |
|
_, _, _) = systematic_resampling(log_weight_acc, |
|
to_resample, |
|
num_samples, |
|
batch_size) |
|
elif "relaxed" in resampling_type: |
|
if aux: |
|
if resampling_type == "relaxed-logblend": |
|
(resampled, |
|
resampling_log_prob, |
|
prev_log_r_zt, |
|
_, _, _) = relaxed_resampling(log_weight_acc, |
|
to_resample, |
|
num_samples, |
|
batch_size, |
|
temperature=resampling_temperature, |
|
log_r_x=prev_log_r_zt, |
|
blend_type="log") |
|
elif resampling_type == "relaxed-linearblend": |
|
(resampled, |
|
resampling_log_prob, |
|
prev_log_r_zt, |
|
_, _, _) = relaxed_resampling(log_weight_acc, |
|
to_resample, |
|
num_samples, |
|
batch_size, |
|
temperature=resampling_temperature, |
|
log_r_x=prev_log_r_zt, |
|
blend_type="linear") |
|
elif resampling_type == "relaxed-stateblend": |
|
(resampled, |
|
resampling_log_prob, |
|
_, _, _, _) = relaxed_resampling(log_weight_acc, |
|
to_resample, |
|
num_samples, |
|
batch_size, |
|
temperature=resampling_temperature) |
|
|
|
prev_r_zt = model.r.r_xn(resampled[0], t) |
|
prev_log_r_zt = tf.reduce_sum( |
|
prev_r_zt.log_prob(observation), axis=[1]) |
|
elif resampling_type == "relaxed-stateblend-st": |
|
(resampled, |
|
resampling_log_prob, |
|
_, _, _, _) = relaxed_resampling(log_weight_acc, |
|
to_resample, |
|
num_samples, |
|
batch_size, |
|
temperature=resampling_temperature, |
|
straight_through=True) |
|
|
|
prev_r_zt = model.r.r_xn(resampled[0], t) |
|
prev_log_r_zt = tf.reduce_sum( |
|
prev_r_zt.log_prob(observation), axis=[1]) |
|
else: |
|
(resampled, |
|
resampling_log_prob, |
|
_, _, _, _) = relaxed_resampling(log_weight_acc, |
|
to_resample, |
|
num_samples, |
|
batch_size, |
|
temperature=resampling_temperature) |
|
|
|
|
|
|
|
|
|
|
|
resampling_log_probs.append(tf.reduce_sum(resampling_log_prob, axis=0)) |
|
prev_state = resampled[0] |
|
if aux and "relaxed" not in resampling_type: |
|
|
|
|
|
prev_log_r_zt = tf.squeeze(resampled[1]) |
|
|
|
|
|
log_p_hats.append( |
|
tf.reduce_logsumexp(log_weight_acc, axis=0) - tf.log( |
|
tf.cast(num_samples, dtype=observation.dtype))) |
|
|
|
log_weights.append(log_weight_acc) |
|
log_weight_acc = tf.zeros_like(log_weight_acc) |
|
else: |
|
prev_state = states[-1] |
|
|
|
final_update = (tf.reduce_logsumexp(log_weight_acc, axis=0) - |
|
tf.log(tf.cast(num_samples, dtype=observation.dtype))) |
|
|
|
if len(log_p_hats) > 0: |
|
log_p_hat = tf.reduce_sum(log_p_hats, axis=0) + final_update |
|
else: |
|
log_p_hat = final_update |
|
|
|
if use_resampling_grads and any(resampling_schedule): |
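    # Score-function (REINFORCE) gradients for the discrete resampling
    # decisions. The reward for the resampling event at step k is the part of
    # log_p_hat accumulated after that event, and an exponential moving
    # average of the batch-mean reward serves as a baseline to reduce
    # variance.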
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
rewards = tf.stop_gradient( |
|
tf.expand_dims(log_p_hat, 0) - tf.cumsum(log_p_hats, axis=0)) |
|
batch_avg_rewards = tf.reduce_mean(rewards, axis=1) |
|
|
|
|
|
baseline_ema = tf.train.ExponentialMovingAverage(decay=0.94) |
|
maintain_baseline_op = baseline_ema.apply([batch_avg_rewards]) |
|
baseline = tf.expand_dims(baseline_ema.average(batch_avg_rewards), 1) |
|
centered_rewards = rewards - baseline |
|
if summarize: |
|
summ.summarize_learning_signal(rewards, "rewards") |
|
summ.summarize_learning_signal(centered_rewards, "centered_rewards") |
|
|
|
resampling_grads = tf.reduce_sum( |
|
tf.stop_gradient(centered_rewards) * resampling_log_probs, axis=0) |
|
losses = [Loss("log_p_hat", -tf.reduce_mean(log_p_hat)/num_timesteps), |
|
Loss("resampling_grads", -tf.reduce_mean(resampling_grads)/num_timesteps)] |
|
else: |
|
losses = [Loss("log_p_hat", -tf.reduce_mean(log_p_hat)/num_timesteps)] |
|
maintain_baseline_op = tf.no_op() |
|
|
|
log_p_hat /= num_timesteps |
|
|
|
return log_p_hat, losses, maintain_baseline_op, states[1:], log_weights_all |
|
|
|
|
|
def fivo_aux_td( |
|
model, |
|
observation, |
|
num_timesteps, |
|
resampling_schedule, |
|
num_samples=1, |
|
summarize=False): |
|
"""Compute the FIVO_AUX evidence lower bound.""" |
|
|
|
num_instances = tf.cast(tf.shape(observation)[0], tf.int32) |
|
batch_size = tf.cast(num_instances / num_samples, tf.int32) |
|
states = [model.zero_state(num_instances)] |
|
prev_state = states[0] |
|
log_weight_acc = tf.zeros(shape=[num_samples, batch_size], dtype=observation.dtype) |
|
prev_log_r = tf.zeros([num_instances], dtype=observation.dtype) |
|
|
|
log_rs = [] |
|
|
|
r_tilde_params = [model.r_tilde.r_zt(states[0], observation, 0)] |
|
log_r_tildes = [] |
|
log_p_xs = [] |
|
|
|
log_weights = [] |
|
|
|
log_weights_all = [] |
|
log_p_hats = [] |
|
for t in xrange(num_timesteps): |
|
|
|
|
|
|
|
|
|
|
|
(zt, log_q_zt, log_p_zt, log_p_x_given_z, |
|
r_tilde_mu, r_tilde_sigma_sq, p_ztplus1) = model(prev_state, observation, t) |
|
|
|
|
|
log_weight = log_p_zt + log_p_x_given_z - log_q_zt |
|
|
|
|
|
if t == num_timesteps - 1: |
|
log_r = tf.zeros_like(prev_log_r) |
|
else: |
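      # Closed-form log of r_tilde (a Gaussian factor with unit peak)
      # integrated against the model transition p(z_{t+1} | z_t).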
|
p_mu = p_ztplus1.mean() |
|
p_sigma_sq = p_ztplus1.variance() |
|
log_r = (tf.log(r_tilde_sigma_sq) - |
|
tf.log(r_tilde_sigma_sq + p_sigma_sq) - |
|
tf.square(r_tilde_mu - p_mu)/(r_tilde_sigma_sq + p_sigma_sq)) |
|
log_r = 0.5*tf.reduce_sum(log_r, axis=-1) |
|
|
|
|
|
log_weight += log_r - prev_log_r |
|
log_weight_acc += tf.reshape(log_weight, [num_samples, batch_size]) |
|
|
|
|
|
states.append(zt) |
|
log_weights_all.append(log_weight_acc) |
|
log_p_xs.append(log_p_x_given_z) |
|
log_rs.append(log_r) |
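    # Evaluate the previous timestep's r_tilde (up to its normalizing
    # constant) at the newly sampled z_t; this is the prediction regressed in
    # the Bellman loss below.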
|
|
|
|
|
prev_r_tilde_mu, prev_r_tilde_sigma_sq = r_tilde_params[-1] |
|
prev_log_r_tilde = -0.5*tf.reduce_sum( |
|
tf.square(zt - prev_r_tilde_mu)/prev_r_tilde_sigma_sq, axis=-1) |
|
|
|
|
|
log_r_tildes.append(prev_log_r_tilde) |
|
|
|
|
|
if resampling_schedule[t]: |
|
|
|
if t < num_timesteps - 1: |
|
to_resample = [zt, log_r, r_tilde_mu, r_tilde_sigma_sq] |
|
else: |
|
to_resample = [zt, log_r] |
|
(resampled, |
|
_, _, _, _) = multinomial_resampling(log_weight_acc, |
|
to_resample, |
|
num_samples, |
|
batch_size) |
|
prev_state = resampled[0] |
|
|
|
|
|
prev_log_r = tf.squeeze(resampled[1]) |
|
if t < num_timesteps -1: |
|
r_tilde_params.append((resampled[2], resampled[3])) |
|
|
|
|
|
log_p_hats.append( |
|
tf.reduce_logsumexp(log_weight_acc, axis=0) - tf.log( |
|
tf.cast(num_samples, dtype=observation.dtype))) |
|
|
|
log_weights.append(log_weight_acc) |
|
log_weight_acc = tf.zeros_like(log_weight_acc) |
|
else: |
|
prev_state = zt |
|
prev_log_r = log_r |
|
if t < num_timesteps - 1: |
|
r_tilde_params.append((r_tilde_mu, r_tilde_sigma_sq)) |
|
|
|
|
|
final_update = (tf.reduce_logsumexp(log_weight_acc, axis=0) - |
|
tf.log(tf.cast(num_samples, dtype=observation.dtype))) |
|
|
|
if len(log_p_hats) > 0: |
|
log_p_hat = tf.reduce_sum(log_p_hats, axis=0) + final_update |
|
else: |
|
log_p_hat = final_update |
|
|
|
|
|
|
|
|
|
log_p_x = tf.reshape(tf.stack(log_p_xs), |
|
[num_timesteps, num_samples, batch_size]) |
|
|
|
|
|
log_r = tf.reshape(tf.stack(log_rs), |
|
[num_timesteps, num_samples, batch_size]) |
|
|
|
log_r_tilde = tf.reshape(tf.stack(log_r_tildes), |
|
[num_timesteps, num_samples, batch_size]) |
|
log_lambda = tf.reduce_mean(log_r_tilde - log_p_x - log_r, axis=1, |
|
keepdims=True) |
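  # Fit r_tilde with a squared-error ("Bellman") regression: the target for
  # log r_tilde at step t is log p(x_t | z_t) + log r(z_t) plus the
  # per-timestep constant log_lambda, held fixed with stop_gradient.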
|
bellman_sos = tf.reduce_mean(tf.square( |
|
log_r_tilde - tf.stop_gradient(log_lambda + log_p_x + log_r)), axis=[0, 1]) |
|
bellman_loss = tf.reduce_mean(bellman_sos)/num_timesteps |
|
tf.summary.scalar("bellman_loss", bellman_loss) |
|
|
|
if len(tf.get_collection("LOG_P_HAT_VARS")) == 0: |
|
log_p_hat_collection = list(set(tf.trainable_variables()) - |
|
set(tf.get_collection("R_TILDE_VARS"))) |
|
for v in log_p_hat_collection: |
|
tf.add_to_collection("LOG_P_HAT_VARS", v) |
|
|
|
log_p_hat /= num_timesteps |
|
losses = [Loss("log_p_hat", -tf.reduce_mean(log_p_hat), "LOG_P_HAT_VARS"), |
|
Loss("bellman_loss", bellman_loss, "R_TILDE_VARS")] |
|
|
|
return log_p_hat, losses, tf.no_op(), states[1:], log_weights_all |
|
|