|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Utils for plotting and summarizing. |
|
""" |
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import matplotlib.gridspec as gridspec |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import scipy |
|
|
|
import tensorflow as tf |
|
|
|
import models |
|
|
|
|
|
def summarize_ess(weights, only_last_timestep=False): |
|
"""Plots the effective sample size. |
|
|
|
Args: |
|
weights: List of length num_timesteps Tensors of shape |
|
[num_samples, batch_size] |
|
""" |
|
num_timesteps = len(weights) |
|
batch_size = tf.cast(tf.shape(weights[0])[1], dtype=tf.float64) |
|
for i in range(num_timesteps): |
|
if only_last_timestep and i < num_timesteps-1: continue |
|
|
|
w = tf.nn.softmax(weights[i], dim=0) |
|
centered_weights = w - tf.reduce_mean(w, axis=0, keepdims=True) |
|
variance = tf.reduce_sum(tf.square(centered_weights))/(batch_size-1) |
|
ess = 1./tf.reduce_mean(tf.reduce_sum(tf.square(w), axis=0)) |
|
tf.summary.scalar("ess/%d" % i, ess) |
|
tf.summary.scalar("ese/%d" % i, ess / batch_size) |
|
tf.summary.scalar("weight_variance/%d" % i, variance) |
|
|
|
|
|
def summarize_particles(states, weights, observation, model): |
|
"""Plots particle locations and weights. |
|
|
|
Args: |
|
states: List of length num_timesteps Tensors of shape |
|
[batch_size*num_particles, state_size]. |
|
weights: List of length num_timesteps Tensors of shape [num_samples, |
|
batch_size] |
|
observation: Tensor of shape [batch_size*num_samples, state_size] |
|
""" |
|
num_timesteps = len(weights) |
|
num_samples, batch_size = weights[0].get_shape().as_list() |
|
|
|
q0_dist = model.q.q_zt(observation, tf.zeros_like(states[0]), 0) |
|
q0_loc = q0_dist.loc[0:batch_size, 0] |
|
q0_scale = q0_dist.scale[0:batch_size, 0] |
|
|
|
post = (model.p.mixing_coeff, model.p.prior_mode_mean, model.p.variance, |
|
tf.reduce_sum(model.p.bs), model.p.num_timesteps) |
|
|
|
|
|
states = tf.stack(states) |
|
weights = tf.stack(weights) |
|
|
|
weights = tf.nn.softmax(weights, dim=1) |
|
states = tf.reshape(states, tf.shape(weights)) |
|
|
|
ess = 1./tf.reduce_sum(tf.square(weights), axis=1) |
|
|
|
def _plot_states(states_batch, weights_batch, observation_batch, ess_batch, q0, post): |
|
""" |
|
states: [time, num_samples, batch_size] |
|
weights [time, num_samples, batch_size] |
|
observation: [batch_size, 1] |
|
q0: ([batch_size], [batch_size]) |
|
post: ... |
|
""" |
|
num_timesteps, _, batch_size = states_batch.shape |
|
plots = [] |
|
for i in range(batch_size): |
|
states = states_batch[:,:,i] |
|
weights = weights_batch[:,:,i] |
|
observation = observation_batch[i] |
|
ess = ess_batch[:,i] |
|
q0_loc = q0[0][i] |
|
q0_scale = q0[1][i] |
|
|
|
fig = plt.figure(figsize=(7, (num_timesteps + 1) * 2)) |
|
|
|
|
|
|
|
|
|
bar_rows = 8 |
|
num_rows = (num_timesteps + 1) * (bar_rows + 1) |
|
gs = gridspec.GridSpec(num_rows, 1) |
|
|
|
|
|
prior_lims = (post[1] * -2, post[1] * 2) |
|
q_lims = (scipy.stats.norm.ppf(0.01, loc=q0_loc, scale=q0_scale), |
|
scipy.stats.norm.ppf(0.99, loc=q0_loc, scale=q0_scale)) |
|
state_width = states.max() - states.min() |
|
state_lims = (states.min() - state_width * 0.15, |
|
states.max() + state_width * 0.15) |
|
|
|
lims = (min(prior_lims[0], q_lims[0], state_lims[0]), |
|
max(prior_lims[1], q_lims[1], state_lims[1])) |
|
|
|
z0 = np.arange(lims[0], lims[1], 0.1) |
|
alpha, pos_mu, sigma_sq, B, T = post |
|
neg_mu = -pos_mu |
|
scale = np.sqrt((T + 1) * sigma_sq) |
|
p_zn = ( |
|
alpha * scipy.stats.norm.pdf( |
|
observation, loc=pos_mu + B, scale=scale) + (1 - alpha) * |
|
scipy.stats.norm.pdf(observation, loc=neg_mu + B, scale=scale)) |
|
p_z0 = ( |
|
alpha * scipy.stats.norm.pdf(z0, loc=pos_mu, scale=np.sqrt(sigma_sq)) |
|
+ (1 - alpha) * scipy.stats.norm.pdf( |
|
z0, loc=neg_mu, scale=np.sqrt(sigma_sq))) |
|
p_zn_given_z0 = scipy.stats.norm.pdf( |
|
observation, loc=z0 + B, scale=np.sqrt(T * sigma_sq)) |
|
post_z0 = (p_z0 * p_zn_given_z0) / p_zn |
|
|
|
q_z0 = scipy.stats.norm.pdf(z0, loc=q0_loc, scale=q0_scale) |
|
ax = plt.subplot(gs[0:bar_rows, :]) |
|
ax.plot(z0, q_z0, color="blue") |
|
ax.plot(z0, post_z0, color="green") |
|
ax.plot(z0, p_z0, color="red") |
|
ax.legend(("q", "posterior", "prior"), loc="best", prop={"size": 10}) |
|
|
|
ax.set_xticks([]) |
|
ax.set_xlim(*lims) |
|
|
|
|
|
for t in range(num_timesteps): |
|
start = (t + 1) * (bar_rows + 1) |
|
ax1 = plt.subplot(gs[start:start + bar_rows, :]) |
|
ax2 = plt.subplot(gs[start + bar_rows:start + bar_rows + 1, :]) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ax1.bar(states[t,:], weights[t,:], width=0.02, alpha=0.2, edgecolor = "none") |
|
ax1.set_ylabel("t=%d" % t) |
|
ax1.set_xticks([]) |
|
ax1.grid(True, which="both") |
|
ax1.set_xlim(*lims) |
|
|
|
ax1.axvline(x=observation, color="red", linestyle="dashed") |
|
|
|
ax1.text(0.1, 0.9, "ESS: %0.2f" % ess[t], |
|
ha='center', va='center', transform=ax1.transAxes) |
|
|
|
|
|
ax2.hist2d( |
|
states[t, :], np.zeros_like(states[t, :]), bins=[50, 1], cmap="Greys") |
|
ax2.grid(False) |
|
ax2.set_yticks([]) |
|
ax2.set_xlim(*lims) |
|
if t != num_timesteps - 1: |
|
ax2.set_xticks([]) |
|
|
|
fig.canvas.draw() |
|
p = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") |
|
plots.append(p.reshape(fig.canvas.get_width_height()[::-1] + (3,))) |
|
plt.close(fig) |
|
return np.stack(plots) |
|
|
|
plots = tf.py_func(_plot_states, |
|
[states, weights, observation, ess, (q0_loc, q0_scale), post], |
|
[tf.uint8])[0] |
|
tf.summary.image("states", plots, 5, collections=["infrequent_summaries"]) |
|
|
|
|
|
def plot_weights(weights, resampled=None): |
|
"""Plots the weights and effective sample size from an SMC rollout. |
|
|
|
Args: |
|
weights: [num_timesteps, num_samples, batch_size] importance weights |
|
resampled: [num_timesteps] 0/1 indicating if resampling ocurred |
|
""" |
|
weights = tf.convert_to_tensor(weights) |
|
|
|
def _make_plots(weights, resampled): |
|
num_timesteps, num_samples, batch_size = weights.shape |
|
plots = [] |
|
for i in range(batch_size): |
|
fig, axes = plt.subplots(nrows=1, sharex=True, figsize=(8, 4)) |
|
axes.stackplot(np.arange(num_timesteps), np.transpose(weights[:, :, i])) |
|
axes.set_title("Weights") |
|
axes.set_xlabel("Steps") |
|
axes.set_ylim([0, 1]) |
|
axes.set_xlim([0, num_timesteps - 1]) |
|
for j in np.where(resampled > 0)[0]: |
|
axes.axvline(x=j, color="red", linestyle="dashed", ymin=0.0, ymax=1.0) |
|
fig.canvas.draw() |
|
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") |
|
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) |
|
plots.append(data) |
|
plt.close(fig) |
|
return np.stack(plots, axis=0) |
|
|
|
if resampled is None: |
|
num_timesteps, _, batch_size = weights.get_shape().as_list() |
|
resampled = tf.zeros([num_timesteps], dtype=tf.float32) |
|
plots = tf.py_func(_make_plots, |
|
[tf.nn.softmax(weights, dim=1), |
|
tf.to_float(resampled)], [tf.uint8])[0] |
|
batch_size = weights.get_shape().as_list()[-1] |
|
tf.summary.image( |
|
"weights", plots, batch_size, collections=["infrequent_summaries"]) |
|
|
|
|
|
def summarize_weights(weights, num_timesteps, num_samples): |
|
|
|
weights = tf.convert_to_tensor(weights) |
|
mean = tf.reduce_mean(weights, axis=1, keepdims=True) |
|
squared_diff = tf.square(weights - mean) |
|
variances = tf.reduce_sum(squared_diff, axis=1) / (num_samples - 1) |
|
|
|
variances = tf.reduce_mean(variances, axis=1) |
|
avg_magnitude = tf.reduce_mean(tf.abs(weights), axis=[1, 2]) |
|
for t in xrange(num_timesteps): |
|
tf.summary.scalar("weights/variance_%d" % t, variances[t]) |
|
tf.summary.scalar("weights/magnitude_%d" % t, avg_magnitude[t]) |
|
tf.summary.histogram("weights/step_%d" % t, weights[t]) |
|
|
|
|
|
def summarize_learning_signal(rewards, tag): |
|
num_resampling_events, _ = rewards.get_shape().as_list() |
|
mean = tf.reduce_mean(rewards, axis=1) |
|
avg_magnitude = tf.reduce_mean(tf.abs(rewards), axis=1) |
|
reward_square = tf.reduce_mean(tf.square(rewards), axis=1) |
|
for t in xrange(num_resampling_events): |
|
tf.summary.scalar("%s/mean_%d" % (tag, t), mean[t]) |
|
tf.summary.scalar("%s/magnitude_%d" % (tag, t), avg_magnitude[t]) |
|
tf.summary.scalar("%s/squared_%d" % (tag, t), reward_square[t]) |
|
tf.summary.histogram("%s/step_%d" % (tag, t), rewards[t]) |
|
|
|
|
|
def summarize_qs(model, observation, states): |
|
model.q.summarize_weights() |
|
if hasattr(model.p, "posterior") and callable(getattr(model.p, "posterior")): |
|
states = [tf.zeros_like(states[0])] + states[:-1] |
|
for t, prev_state in enumerate(states): |
|
p = model.p.posterior(observation, prev_state, t) |
|
q = model.q.q_zt(observation, prev_state, t) |
|
kl = tf.reduce_mean(tf.contrib.distributions.kl_divergence(p, q)) |
|
tf.summary.scalar("kl_q/%d" % t, tf.reduce_mean(kl)) |
|
mean_diff = q.loc - p.loc |
|
mean_abs_err = tf.abs(mean_diff) |
|
mean_rel_err = tf.abs(mean_diff / p.loc) |
|
tf.summary.scalar("q_mean_convergence/absolute_error_%d" % t, |
|
tf.reduce_mean(mean_abs_err)) |
|
tf.summary.scalar("q_mean_convergence/relative_error_%d" % t, |
|
tf.reduce_mean(mean_rel_err)) |
|
sigma_diff = tf.square(q.scale) - tf.square(p.scale) |
|
sigma_abs_err = tf.abs(sigma_diff) |
|
sigma_rel_err = tf.abs(sigma_diff / tf.square(p.scale)) |
|
tf.summary.scalar("q_variance_convergence/absolute_error_%d" % t, |
|
tf.reduce_mean(sigma_abs_err)) |
|
tf.summary.scalar("q_variance_convergence/relative_error_%d" % t, |
|
tf.reduce_mean(sigma_rel_err)) |
|
|
|
|
|
def summarize_rs(model, states): |
|
model.r.summarize_weights() |
|
for t, state in enumerate(states): |
|
true_r = model.p.lookahead(state, t) |
|
r = model.r.r_xn(state, t) |
|
kl = tf.reduce_mean(tf.contrib.distributions.kl_divergence(true_r, r)) |
|
tf.summary.scalar("kl_r/%d" % t, tf.reduce_mean(kl)) |
|
mean_diff = true_r.loc - r.loc |
|
mean_abs_err = tf.abs(mean_diff) |
|
mean_rel_err = tf.abs(mean_diff / true_r.loc) |
|
tf.summary.scalar("r_mean_convergence/absolute_error_%d" % t, |
|
tf.reduce_mean(mean_abs_err)) |
|
tf.summary.scalar("r_mean_convergence/relative_error_%d" % t, |
|
tf.reduce_mean(mean_rel_err)) |
|
sigma_diff = tf.square(r.scale) - tf.square(true_r.scale) |
|
sigma_abs_err = tf.abs(sigma_diff) |
|
sigma_rel_err = tf.abs(sigma_diff / tf.square(true_r.scale)) |
|
tf.summary.scalar("r_variance_convergence/absolute_error_%d" % t, |
|
tf.reduce_mean(sigma_abs_err)) |
|
tf.summary.scalar("r_variance_convergence/relative_error_%d" % t, |
|
tf.reduce_mean(sigma_rel_err)) |
|
|
|
|
|
def summarize_model(model, true_bs, observation, states, bound, summarize_r=True): |
|
if hasattr(model.p, "bs"): |
|
model_b = tf.reduce_sum(model.p.bs, axis=0) |
|
true_b = tf.reduce_sum(true_bs, axis=0) |
|
abs_err = tf.abs(model_b - true_b) |
|
rel_err = abs_err / true_b |
|
tf.summary.scalar("sum_of_bs/data_generating_process", tf.reduce_mean(true_b)) |
|
tf.summary.scalar("sum_of_bs/model", tf.reduce_mean(model_b)) |
|
tf.summary.scalar("sum_of_bs/absolute_error", tf.reduce_mean(abs_err)) |
|
tf.summary.scalar("sum_of_bs/relative_error", tf.reduce_mean(rel_err)) |
|
|
|
|
|
|
|
|
|
|
|
def summarize_grads(grads, loss_name): |
|
grad_ema = tf.train.ExponentialMovingAverage(decay=0.99) |
|
vectorized_grads = tf.concat( |
|
[tf.reshape(g, [-1]) for g, _ in grads if g is not None], axis=0) |
|
new_second_moments = tf.square(vectorized_grads) |
|
new_first_moments = vectorized_grads |
|
maintain_grad_ema_op = grad_ema.apply([new_first_moments, new_second_moments]) |
|
first_moments = grad_ema.average(new_first_moments) |
|
second_moments = grad_ema.average(new_second_moments) |
|
variances = second_moments - tf.square(first_moments) |
|
tf.summary.scalar("grad_variance/%s" % loss_name, tf.reduce_mean(variances)) |
|
tf.summary.histogram("grad_variance/%s" % loss_name, variances) |
|
tf.summary.histogram("grad_mean/%s" % loss_name, first_moments) |
|
return maintain_grad_ema_op |
|
|