NCTCMumbai's picture
Upload 2583 files
97b6013 verified
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import namedtuple
import tensorflow as tf
import summary_utils as summ
Loss = namedtuple("Loss", "name loss vars")
Loss.__new__.__defaults__ = (tf.GraphKeys.TRAINABLE_VARIABLES,)
def iwae(model, observation, num_timesteps, num_samples=1,
summarize=False):
"""Compute the IWAE evidence lower bound.
Args:
model: A callable that computes one timestep of the model.
observation: A shape [batch_size*num_samples, state_size] Tensor
containing z_n, the observation for each sequence in the batch.
num_timesteps: The number of timesteps in each sequence, an integer.
num_samples: The number of samples to use to compute the IWAE bound.
Returns:
log_p_hat: The IWAE estimator of the lower bound on the log marginal.
loss: A tensor that you can perform gradient descent on to optimize the
bound.
maintain_ema_op: A no-op included for compatibility with FIVO.
states: The sequence of states sampled.
"""
# Initialization
num_instances = tf.shape(observation)[0]
batch_size = tf.cast(num_instances / num_samples, tf.int32)
states = [model.zero_state(num_instances)]
log_weights = []
log_weight_acc = tf.zeros([num_samples, batch_size], dtype=observation.dtype)
for t in xrange(num_timesteps):
# run the model for one timestep
(zt, log_q_zt, log_p_zt, log_p_x_given_z, _) = model(
states[-1], observation, t)
# update accumulators
states.append(zt)
log_weight = log_p_zt + log_p_x_given_z - log_q_zt
log_weight_acc += tf.reshape(log_weight, [num_samples, batch_size])
if summarize:
weight_dist = tf.contrib.distributions.Categorical(
logits=tf.transpose(log_weight_acc, perm=[1, 0]),
allow_nan_stats=False)
weight_entropy = weight_dist.entropy()
weight_entropy = tf.reduce_mean(weight_entropy)
tf.summary.scalar("weight_entropy/%d" % t, weight_entropy)
log_weights.append(log_weight_acc)
# Compute the lower bound on the log evidence.
log_p_hat = (tf.reduce_logsumexp(log_weight_acc, axis=0) -
tf.log(tf.cast(num_samples, observation.dtype))) / num_timesteps
loss = -tf.reduce_mean(log_p_hat)
losses = [Loss("log_p_hat", loss)]
# we clip off the initial state before returning.
# there are no emas for iwae, so we return a noop for that
return log_p_hat, losses, tf.no_op(), states[1:], log_weights
def multinomial_resampling(log_weights, states, n, b):
"""Resample states with multinomial resampling.
Args:
log_weights: A (n x b) Tensor representing a batch of b logits for n-ary
Categorical distribution.
states: A list of (b*n x d) Tensors that will be resample in from the groups
of every n-th row.
Returns:
resampled_states: A list of (b*n x d) Tensors resampled via stratified sampling.
log_probs: A (n x b) Tensor of the log probabilities of the ancestry decisions.
resampling_parameters: The Tensor of parameters of the resampling distribution.
ancestors: An (n x b) Tensor of integral indices representing the ancestry decisions.
resampling_dist: The distribution object for resampling.
"""
log_weights = tf.convert_to_tensor(log_weights)
states = [tf.convert_to_tensor(state) for state in states]
resampling_parameters = tf.transpose(log_weights, perm=[1,0])
resampling_dist = tf.contrib.distributions.Categorical(logits=resampling_parameters)
ancestors = tf.stop_gradient(
resampling_dist.sample(sample_shape=n))
log_probs = resampling_dist.log_prob(ancestors)
offset = tf.expand_dims(tf.range(b), 0)
ancestor_inds = tf.reshape(ancestors * b + offset, [-1])
resampled_states = []
for state in states:
resampled_states.append(tf.gather(state, ancestor_inds))
return resampled_states, log_probs, resampling_parameters, ancestors, resampling_dist
def stratified_resampling(log_weights, states, n, b):
"""Resample states with straitified resampling.
Args:
log_weights: A (n x b) Tensor representing a batch of b logits for n-ary
Categorical distribution.
states: A list of (b*n x d) Tensors that will be resample in from the groups
of every n-th row.
Returns:
resampled_states: A list of (b*n x d) Tensors resampled via stratified sampling.
log_probs: A (n x b) Tensor of the log probabilities of the ancestry decisions.
resampling_parameters: The Tensor of parameters of the resampling distribution.
ancestors: An (n x b) Tensor of integral indices representing the ancestry decisions.
resampling_dist: The distribution object for resampling.
"""
log_weights = tf.convert_to_tensor(log_weights)
states = [tf.convert_to_tensor(state) for state in states]
log_weights = tf.transpose(log_weights, perm=[1,0])
probs = tf.nn.softmax(
tf.tile(tf.expand_dims(log_weights, axis=1),
[1, n, 1])
)
cdfs = tf.concat([tf.zeros((b,n,1), dtype=probs.dtype), tf.cumsum(probs, axis=2)], 2)
bins = tf.range(n, dtype=probs.dtype) / n
bins = tf.tile(tf.reshape(bins, [1,-1,1]), [b,1,n+1])
strat_cdfs = tf.minimum(tf.maximum((cdfs - bins) * n, 0.0), 1.0)
resampling_parameters = strat_cdfs[:,:,1:] - strat_cdfs[:,:,:-1]
resampling_dist = tf.contrib.distributions.Categorical(
probs = resampling_parameters,
allow_nan_stats=False)
ancestors = tf.stop_gradient(
resampling_dist.sample())
log_probs = resampling_dist.log_prob(ancestors)
ancestors = tf.transpose(ancestors, perm=[1,0])
log_probs = tf.transpose(log_probs, perm=[1,0])
offset = tf.expand_dims(tf.range(b), 0)
ancestor_inds = tf.reshape(ancestors * b + offset, [-1])
resampled_states = []
for state in states:
resampled_states.append(tf.gather(state, ancestor_inds))
return resampled_states, log_probs, resampling_parameters, ancestors, resampling_dist
def systematic_resampling(log_weights, states, n, b):
"""Resample states with systematic resampling.
Args:
log_weights: A (n x b) Tensor representing a batch of b logits for n-ary
Categorical distribution.
states: A list of (b*n x d) Tensors that will be resample in from the groups
of every n-th row.
Returns:
resampled_states: A list of (b*n x d) Tensors resampled via stratified sampling.
log_probs: A (n x b) Tensor of the log probabilities of the ancestry decisions.
resampling_parameters: The Tensor of parameters of the resampling distribution.
ancestors: An (n x b) Tensor of integral indices representing the ancestry decisions.
resampling_dist: The distribution object for resampling.
"""
log_weights = tf.convert_to_tensor(log_weights)
states = [tf.convert_to_tensor(state) for state in states]
log_weights = tf.transpose(log_weights, perm=[1,0])
probs = tf.nn.softmax(
tf.tile(tf.expand_dims(log_weights, axis=1),
[1, n, 1])
)
cdfs = tf.concat([tf.zeros((b,n,1), dtype=probs.dtype), tf.cumsum(probs, axis=2)], 2)
bins = tf.range(n, dtype=probs.dtype) / n
bins = tf.tile(tf.reshape(bins, [1,-1,1]), [b,1,n+1])
strat_cdfs = tf.minimum(tf.maximum((cdfs - bins) * n, 0.0), 1.0)
resampling_parameters = strat_cdfs[:,:,1:] - strat_cdfs[:,:,:-1]
resampling_dist = tf.contrib.distributions.Categorical(
probs=resampling_parameters,
allow_nan_stats=True)
U = tf.random_uniform((b, 1, 1), dtype=probs.dtype)
ancestors = tf.stop_gradient(tf.reduce_sum(tf.to_float(U > strat_cdfs[:,:,1:]), axis=-1))
log_probs = resampling_dist.log_prob(ancestors)
ancestors = tf.transpose(ancestors, perm=[1,0])
log_probs = tf.transpose(log_probs, perm=[1,0])
offset = tf.expand_dims(tf.range(b, dtype=probs.dtype), 0)
ancestor_inds = tf.reshape(ancestors * b + offset, [-1])
resampled_states = []
for state in states:
resampled_states.append(tf.gather(state, ancestor_inds))
return resampled_states, log_probs, resampling_parameters, ancestors, resampling_dist
def log_blend(inputs, weights):
"""Blends state in the log space.
Args:
inputs: A set of scalar states, one for each particle in each particle filter.
Should be [num_samples, batch_size].
weights: A set of weights used to blend the state. Each set of weights
should be of dimension [num_samples] (one weight for each previous particle).
There should be one set of weights for each new particle in each particle filter.
Thus the shape should be [num_samples, batch_size, num_samples] where
the first axis indexes new particle and the last axis indexes old particles.
Returns:
blended: The blended states, a tensor of shape [num_samples, batch_size].
"""
raw_max = tf.reduce_max(inputs, axis=0, keepdims=True)
my_max = tf.stop_gradient(
tf.where(tf.is_finite(raw_max), raw_max, tf.zeros_like(raw_max))
)
# Don't ask.
blended = tf.log(tf.einsum("ijk,kj->ij", weights, tf.exp(inputs - raw_max))) + my_max
return blended
def relaxed_resampling(log_weights, states, num_samples, batch_size,
log_r_x=None, blend_type="log", temperature=0.5,
straight_through=False):
"""Resample states with relaxed resampling.
Args:
log_weights: A (n x b) Tensor representing a batch of b logits for n-ary
Categorical distribution.
states: A list of (b*n x d) Tensors that will be resample in from the groups
of every n-th row.
Returns:
resampled_states: A list of (b*n x d) Tensors resampled via stratified sampling.
log_probs: A (n x b) Tensor of the log probabilities of the ancestry decisions.
resampling_parameters: The Tensor of parameters of the resampling distribution.
ancestors: An (n x b x n) Tensor of relaxed one hot representations of the ancestry decisions.
resampling_dist: The distribution object for resampling.
"""
assert blend_type in ["log", "linear"], "Blend type must be 'log' or 'linear'."
log_weights = tf.convert_to_tensor(log_weights)
states = [tf.convert_to_tensor(state) for state in states]
state_dim = states[0].get_shape().as_list()[-1]
# weights are num_samples by batch_size, so we transpose to get a
# set of batch_size distributions over [0,num_samples).
resampling_parameters = tf.transpose(log_weights, perm=[1, 0])
resampling_dist = tf.contrib.distributions.RelaxedOneHotCategorical(
temperature,
logits=resampling_parameters)
# sample num_samples samples from the distribution, resulting in a
# [num_samples, batch_size, num_samples] Tensor that represents a set of
# [num_samples, batch_size] blending weights. The dimensions represent
# [sample index, batch index, blending weight index]
ancestors = resampling_dist.sample(sample_shape=num_samples)
if straight_through:
# Forward pass discrete choices, backwards pass soft choices
hard_ancestor_indices = tf.argmax(ancestors, axis=-1)
hard_ancestors = tf.one_hot(hard_ancestor_indices, num_samples,
dtype=ancestors.dtype)
ancestors = tf.stop_gradient(hard_ancestors - ancestors) + ancestors
log_probs = resampling_dist.log_prob(ancestors)
if log_r_x is not None and blend_type == "log":
log_r_x = tf.reshape(log_r_x, [num_samples, batch_size])
log_r_x = log_blend(log_r_x, ancestors)
log_r_x = tf.reshape(log_r_x, [num_samples*batch_size])
elif log_r_x is not None and blend_type == "linear":
# If blend type is linear just add log_r to the states that will be blended
# linearly.
states.append(log_r_x)
# transpose the 'indices' to be [batch_index, blending weight index, sample index]
ancestor_inds = tf.transpose(ancestors, perm=[1, 2, 0])
resampled_states = []
for state in states:
# state is currently [num_samples * batch_size, state_dim] so we reshape
# to [num_samples, batch_size, state_dim] and then transpose to
# [batch_size, state_size, num_samples]
state = tf.transpose(tf.reshape(state, [num_samples, batch_size, -1]), perm=[1, 2, 0])
# state is now (batch_size, state_size, num_samples)
# and ancestor is (batch index, blending weight index, sample index)
# multiplying these gives a matrix of size [batch_size, state_size, num_samples]
next_state = tf.matmul(state, ancestor_inds)
# transpose the state to be [num_samples, batch_size, state_size]
# and then reshape it to match the state format.
next_state = tf.reshape(tf.transpose(next_state, perm=[2,0,1]), [num_samples*batch_size, state_dim])
resampled_states.append(next_state)
new_dist = tf.contrib.distributions.Categorical(
logits=resampling_parameters)
if log_r_x is not None and blend_type == "linear":
# If blend type is linear pop off log_r that we added to the states.
log_r_x = tf.squeeze(resampled_states[-1])
resampled_states = resampled_states[:-1]
return resampled_states, log_probs, log_r_x, resampling_parameters, ancestors, new_dist
def fivo(model,
observation,
num_timesteps,
resampling_schedule,
num_samples=1,
use_resampling_grads=True,
resampling_type="multinomial",
resampling_temperature=0.5,
aux=True,
summarize=False):
"""Compute the FIVO evidence lower bound.
Args:
model: A callable that computes one timestep of the model.
observation: A shape [batch_size*num_samples, state_size] Tensor
containing z_n, the observation for each sequence in the batch.
num_timesteps: The number of timesteps in each sequence, an integer.
resampling_schedule: A list of booleans of length num_timesteps, contains
True if a resampling should occur on a specific timestep.
num_samples: The number of samples to use to compute the IWAE bound.
use_resampling_grads: Whether or not to include the resampling gradients
in loss.
resampling type: The type of resampling, one of "multinomial", "stratified",
"relaxed-logblend", "relaxed-linearblend", "relaxed-stateblend", or
"systematic".
resampling_temperature: A positive temperature only used for relaxed
resampling.
aux: If true, compute the FIVO-AUX bound.
Returns:
log_p_hat: The IWAE estimator of the lower bound on the log marginal.
loss: A tensor that you can perform gradient descent on to optimize the
bound.
maintain_ema_op: An op to update the baseline ema used for the resampling
gradients.
states: The sequence of states sampled.
"""
# Initialization
num_instances = tf.cast(tf.shape(observation)[0], tf.int32)
batch_size = tf.cast(num_instances / num_samples, tf.int32)
states = [model.zero_state(num_instances)]
prev_state = states[0]
log_weight_acc = tf.zeros(shape=[num_samples, batch_size], dtype=observation.dtype)
prev_log_r_zt = tf.zeros([num_instances], dtype=observation.dtype)
log_weights = []
log_weights_all = []
log_p_hats = []
resampling_log_probs = []
for t in xrange(num_timesteps):
# run the model for one timestep
(zt, log_q_zt, log_p_zt, log_p_x_given_z, log_r_zt) = model(
prev_state, observation, t)
# update accumulators
states.append(zt)
log_weight = log_p_zt + log_p_x_given_z - log_q_zt
if aux:
if t == num_timesteps - 1:
log_weight -= prev_log_r_zt
else:
log_weight += log_r_zt - prev_log_r_zt
prev_log_r_zt = log_r_zt
log_weight_acc += tf.reshape(log_weight, [num_samples, batch_size])
log_weights_all.append(log_weight_acc)
if resampling_schedule[t]:
# These objects will be resampled
to_resample = [states[-1]]
if aux and "relaxed" not in resampling_type:
to_resample.append(prev_log_r_zt)
# do the resampling
if resampling_type == "multinomial":
(resampled,
resampling_log_prob,
_, _, _) = multinomial_resampling(log_weight_acc,
to_resample,
num_samples,
batch_size)
elif resampling_type == "stratified":
(resampled,
resampling_log_prob,
_, _, _) = stratified_resampling(log_weight_acc,
to_resample,
num_samples,
batch_size)
elif resampling_type == "systematic":
(resampled,
resampling_log_prob,
_, _, _) = systematic_resampling(log_weight_acc,
to_resample,
num_samples,
batch_size)
elif "relaxed" in resampling_type:
if aux:
if resampling_type == "relaxed-logblend":
(resampled,
resampling_log_prob,
prev_log_r_zt,
_, _, _) = relaxed_resampling(log_weight_acc,
to_resample,
num_samples,
batch_size,
temperature=resampling_temperature,
log_r_x=prev_log_r_zt,
blend_type="log")
elif resampling_type == "relaxed-linearblend":
(resampled,
resampling_log_prob,
prev_log_r_zt,
_, _, _) = relaxed_resampling(log_weight_acc,
to_resample,
num_samples,
batch_size,
temperature=resampling_temperature,
log_r_x=prev_log_r_zt,
blend_type="linear")
elif resampling_type == "relaxed-stateblend":
(resampled,
resampling_log_prob,
_, _, _, _) = relaxed_resampling(log_weight_acc,
to_resample,
num_samples,
batch_size,
temperature=resampling_temperature)
# Calculate prev_log_r_zt from the post-resampling state
prev_r_zt = model.r.r_xn(resampled[0], t)
prev_log_r_zt = tf.reduce_sum(
prev_r_zt.log_prob(observation), axis=[1])
elif resampling_type == "relaxed-stateblend-st":
(resampled,
resampling_log_prob,
_, _, _, _) = relaxed_resampling(log_weight_acc,
to_resample,
num_samples,
batch_size,
temperature=resampling_temperature,
straight_through=True)
# Calculate prev_log_r_zt from the post-resampling state
prev_r_zt = model.r.r_xn(resampled[0], t)
prev_log_r_zt = tf.reduce_sum(
prev_r_zt.log_prob(observation), axis=[1])
else:
(resampled,
resampling_log_prob,
_, _, _, _) = relaxed_resampling(log_weight_acc,
to_resample,
num_samples,
batch_size,
temperature=resampling_temperature)
#if summarize:
# resampling_entropy = resampling_dist.entropy()
# resampling_entropy = tf.reduce_mean(resampling_entropy)
# tf.summary.scalar("weight_entropy/%d" % t, resampling_entropy)
resampling_log_probs.append(tf.reduce_sum(resampling_log_prob, axis=0))
prev_state = resampled[0]
if aux and "relaxed" not in resampling_type:
# Squeeze out the extra dim potentially added by resampling.
# prev_log_r_zt should always be [num_instances]
prev_log_r_zt = tf.squeeze(resampled[1])
# Update the log p hat estimate, taking a log sum exp over the sample
# dimension. The appended tensor is [batch_size].
log_p_hats.append(
tf.reduce_logsumexp(log_weight_acc, axis=0) - tf.log(
tf.cast(num_samples, dtype=observation.dtype)))
# reset the weights
log_weights.append(log_weight_acc)
log_weight_acc = tf.zeros_like(log_weight_acc)
else:
prev_state = states[-1]
# Compute the final weight update. If we just resampled this will be zero.
final_update = (tf.reduce_logsumexp(log_weight_acc, axis=0) -
tf.log(tf.cast(num_samples, dtype=observation.dtype)))
# If we ever resampled, then sum up the previous log p hat terms
if len(log_p_hats) > 0:
log_p_hat = tf.reduce_sum(log_p_hats, axis=0) + final_update
else: # otherwise, log_p_hat only comes from the final update
log_p_hat = final_update
if use_resampling_grads and any(resampling_schedule):
# compute the rewards
# cumsum([a, b, c]) => [a, a+b, a+b+c]
# learning signal at timestep t is
# [sum from i=t+1 to T of log_p_hat_i for t=1:T]
# so we will compute (sum from i=1 to T of log_p_hat_i)
# and at timestep t will subtract off (sum from i=1 to t of log_p_hat_i)
# rewards is a [num_resampling_events, batch_size] Tensor
rewards = tf.stop_gradient(
tf.expand_dims(log_p_hat, 0) - tf.cumsum(log_p_hats, axis=0))
batch_avg_rewards = tf.reduce_mean(rewards, axis=1)
# compute ema baseline.
# centered_rewards is [num_resampling_events, batch_size]
baseline_ema = tf.train.ExponentialMovingAverage(decay=0.94)
maintain_baseline_op = baseline_ema.apply([batch_avg_rewards])
baseline = tf.expand_dims(baseline_ema.average(batch_avg_rewards), 1)
centered_rewards = rewards - baseline
if summarize:
summ.summarize_learning_signal(rewards, "rewards")
summ.summarize_learning_signal(centered_rewards, "centered_rewards")
# compute the loss tensor.
resampling_grads = tf.reduce_sum(
tf.stop_gradient(centered_rewards) * resampling_log_probs, axis=0)
losses = [Loss("log_p_hat", -tf.reduce_mean(log_p_hat)/num_timesteps),
Loss("resampling_grads", -tf.reduce_mean(resampling_grads)/num_timesteps)]
else:
losses = [Loss("log_p_hat", -tf.reduce_mean(log_p_hat)/num_timesteps)]
maintain_baseline_op = tf.no_op()
log_p_hat /= num_timesteps
# we clip off the initial state before returning.
return log_p_hat, losses, maintain_baseline_op, states[1:], log_weights_all
def fivo_aux_td(
model,
observation,
num_timesteps,
resampling_schedule,
num_samples=1,
summarize=False):
"""Compute the FIVO_AUX evidence lower bound."""
# Initialization
num_instances = tf.cast(tf.shape(observation)[0], tf.int32)
batch_size = tf.cast(num_instances / num_samples, tf.int32)
states = [model.zero_state(num_instances)]
prev_state = states[0]
log_weight_acc = tf.zeros(shape=[num_samples, batch_size], dtype=observation.dtype)
prev_log_r = tf.zeros([num_instances], dtype=observation.dtype)
# must be pre-resampling
log_rs = []
# must be post-resampling
r_tilde_params = [model.r_tilde.r_zt(states[0], observation, 0)]
log_r_tildes = []
log_p_xs = []
# contains the weight at each timestep before resampling only on resampling timesteps
log_weights = []
# contains weight at each timestep before resampling
log_weights_all = []
log_p_hats = []
for t in xrange(num_timesteps):
# run the model for one timestep
# zt is state, [num_instances, state_dim]
# log_q_zt, log_p_x_given_z is [num_instances]
# r_tilde_mu, r_tilde_sigma is [num_instances, state_dim]
# p_ztplus1 is a normal distribution on [num_instances, state_dim]
(zt, log_q_zt, log_p_zt, log_p_x_given_z,
r_tilde_mu, r_tilde_sigma_sq, p_ztplus1) = model(prev_state, observation, t)
# Compute the log weight without log r.
log_weight = log_p_zt + log_p_x_given_z - log_q_zt
# Compute log r.
if t == num_timesteps - 1:
log_r = tf.zeros_like(prev_log_r)
else:
p_mu = p_ztplus1.mean()
p_sigma_sq = p_ztplus1.variance()
log_r = (tf.log(r_tilde_sigma_sq) -
tf.log(r_tilde_sigma_sq + p_sigma_sq) -
tf.square(r_tilde_mu - p_mu)/(r_tilde_sigma_sq + p_sigma_sq))
log_r = 0.5*tf.reduce_sum(log_r, axis=-1)
#log_weight += tf.stop_gradient(log_r - prev_log_r)
log_weight += log_r - prev_log_r
log_weight_acc += tf.reshape(log_weight, [num_samples, batch_size])
# Update accumulators
states.append(zt)
log_weights_all.append(log_weight_acc)
log_p_xs.append(log_p_x_given_z)
log_rs.append(log_r)
# Compute log_r_tilde as [num_instances] Tensor.
prev_r_tilde_mu, prev_r_tilde_sigma_sq = r_tilde_params[-1]
prev_log_r_tilde = -0.5*tf.reduce_sum(
tf.square(zt - prev_r_tilde_mu)/prev_r_tilde_sigma_sq, axis=-1)
#tf.square(tf.stop_gradient(zt) - r_tilde_mu)/r_tilde_sigma_sq, axis=-1)
#tf.square(zt - r_tilde_mu)/r_tilde_sigma_sq, axis=-1)
log_r_tildes.append(prev_log_r_tilde)
# optionally resample
if resampling_schedule[t]:
# These objects will be resampled
if t < num_timesteps - 1:
to_resample = [zt, log_r, r_tilde_mu, r_tilde_sigma_sq]
else:
to_resample = [zt, log_r]
(resampled,
_, _, _, _) = multinomial_resampling(log_weight_acc,
to_resample,
num_samples,
batch_size)
prev_state = resampled[0]
# Squeeze out the extra dim potentially added by resampling.
# prev_log_r_zt and log_r_tilde should always be [num_instances]
prev_log_r = tf.squeeze(resampled[1])
if t < num_timesteps -1:
r_tilde_params.append((resampled[2], resampled[3]))
# Update the log p hat estimate, taking a log sum exp over the sample
# dimension. The appended tensor is [batch_size].
log_p_hats.append(
tf.reduce_logsumexp(log_weight_acc, axis=0) - tf.log(
tf.cast(num_samples, dtype=observation.dtype)))
# reset the weights
log_weights.append(log_weight_acc)
log_weight_acc = tf.zeros_like(log_weight_acc)
else:
prev_state = zt
prev_log_r = log_r
if t < num_timesteps - 1:
r_tilde_params.append((r_tilde_mu, r_tilde_sigma_sq))
# Compute the final weight update. If we just resampled this will be zero.
final_update = (tf.reduce_logsumexp(log_weight_acc, axis=0) -
tf.log(tf.cast(num_samples, dtype=observation.dtype)))
# If we ever resampled, then sum up the previous log p hat terms
if len(log_p_hats) > 0:
log_p_hat = tf.reduce_sum(log_p_hats, axis=0) + final_update
else: # otherwise, log_p_hat only comes from the final update
log_p_hat = final_update
# Compute the bellman loss.
# Will remove the first timestep as it is not used.
# log p(x_t|z_t) is in row t-1.
log_p_x = tf.reshape(tf.stack(log_p_xs),
[num_timesteps, num_samples, batch_size])
# log r_t is contained in row t-1.
# last column is zeros (because at timestep T (num_timesteps) r is 1.
log_r = tf.reshape(tf.stack(log_rs),
[num_timesteps, num_samples, batch_size])
# [num_timesteps, num_instances]. log r_tilde_t is in row t-1.
log_r_tilde = tf.reshape(tf.stack(log_r_tildes),
[num_timesteps, num_samples, batch_size])
log_lambda = tf.reduce_mean(log_r_tilde - log_p_x - log_r, axis=1,
keepdims=True)
bellman_sos = tf.reduce_mean(tf.square(
log_r_tilde - tf.stop_gradient(log_lambda + log_p_x + log_r)), axis=[0, 1])
bellman_loss = tf.reduce_mean(bellman_sos)/num_timesteps
tf.summary.scalar("bellman_loss", bellman_loss)
if len(tf.get_collection("LOG_P_HAT_VARS")) == 0:
log_p_hat_collection = list(set(tf.trainable_variables()) -
set(tf.get_collection("R_TILDE_VARS")))
for v in log_p_hat_collection:
tf.add_to_collection("LOG_P_HAT_VARS", v)
log_p_hat /= num_timesteps
losses = [Loss("log_p_hat", -tf.reduce_mean(log_p_hat), "LOG_P_HAT_VARS"),
Loss("bellman_loss", bellman_loss, "R_TILDE_VARS")]
return log_p_hat, losses, tf.no_op(), states[1:], log_weights_all