# Copyright 2018 The TensorFlow Authors All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Utils for plotting and summarizing. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function import matplotlib.gridspec as gridspec import matplotlib.pyplot as plt import numpy as np import scipy import tensorflow as tf import models def summarize_ess(weights, only_last_timestep=False): """Plots the effective sample size. Args: weights: List of length num_timesteps Tensors of shape [num_samples, batch_size] """ num_timesteps = len(weights) batch_size = tf.cast(tf.shape(weights[0])[1], dtype=tf.float64) for i in range(num_timesteps): if only_last_timestep and i < num_timesteps-1: continue w = tf.nn.softmax(weights[i], dim=0) centered_weights = w - tf.reduce_mean(w, axis=0, keepdims=True) variance = tf.reduce_sum(tf.square(centered_weights))/(batch_size-1) ess = 1./tf.reduce_mean(tf.reduce_sum(tf.square(w), axis=0)) tf.summary.scalar("ess/%d" % i, ess) tf.summary.scalar("ese/%d" % i, ess / batch_size) tf.summary.scalar("weight_variance/%d" % i, variance) def summarize_particles(states, weights, observation, model): """Plots particle locations and weights. Args: states: List of length num_timesteps Tensors of shape [batch_size*num_particles, state_size]. weights: List of length num_timesteps Tensors of shape [num_samples, batch_size] observation: Tensor of shape [batch_size*num_samples, state_size] """ num_timesteps = len(weights) num_samples, batch_size = weights[0].get_shape().as_list() # get q0 information for plotting q0_dist = model.q.q_zt(observation, tf.zeros_like(states[0]), 0) q0_loc = q0_dist.loc[0:batch_size, 0] q0_scale = q0_dist.scale[0:batch_size, 0] # get posterior information for plotting post = (model.p.mixing_coeff, model.p.prior_mode_mean, model.p.variance, tf.reduce_sum(model.p.bs), model.p.num_timesteps) # Reshape states and weights to be [time, num_samples, batch_size] states = tf.stack(states) weights = tf.stack(weights) # normalize the weights over the sample dimension weights = tf.nn.softmax(weights, dim=1) states = tf.reshape(states, tf.shape(weights)) ess = 1./tf.reduce_sum(tf.square(weights), axis=1) def _plot_states(states_batch, weights_batch, observation_batch, ess_batch, q0, post): """ states: [time, num_samples, batch_size] weights [time, num_samples, batch_size] observation: [batch_size, 1] q0: ([batch_size], [batch_size]) post: ... """ num_timesteps, _, batch_size = states_batch.shape plots = [] for i in range(batch_size): states = states_batch[:,:,i] weights = weights_batch[:,:,i] observation = observation_batch[i] ess = ess_batch[:,i] q0_loc = q0[0][i] q0_scale = q0[1][i] fig = plt.figure(figsize=(7, (num_timesteps + 1) * 2)) # Each timestep gets two plots -- a bar plot and a histogram of state locs. # The bar plot will be bar_rows rows tall. # The histogram will be 1 row tall. # There is also 1 extra plot at the top showing the posterior and q. bar_rows = 8 num_rows = (num_timesteps + 1) * (bar_rows + 1) gs = gridspec.GridSpec(num_rows, 1) # Figure out how wide to make the plot prior_lims = (post[1] * -2, post[1] * 2) q_lims = (scipy.stats.norm.ppf(0.01, loc=q0_loc, scale=q0_scale), scipy.stats.norm.ppf(0.99, loc=q0_loc, scale=q0_scale)) state_width = states.max() - states.min() state_lims = (states.min() - state_width * 0.15, states.max() + state_width * 0.15) lims = (min(prior_lims[0], q_lims[0], state_lims[0]), max(prior_lims[1], q_lims[1], state_lims[1])) # plot the posterior z0 = np.arange(lims[0], lims[1], 0.1) alpha, pos_mu, sigma_sq, B, T = post neg_mu = -pos_mu scale = np.sqrt((T + 1) * sigma_sq) p_zn = ( alpha * scipy.stats.norm.pdf( observation, loc=pos_mu + B, scale=scale) + (1 - alpha) * scipy.stats.norm.pdf(observation, loc=neg_mu + B, scale=scale)) p_z0 = ( alpha * scipy.stats.norm.pdf(z0, loc=pos_mu, scale=np.sqrt(sigma_sq)) + (1 - alpha) * scipy.stats.norm.pdf( z0, loc=neg_mu, scale=np.sqrt(sigma_sq))) p_zn_given_z0 = scipy.stats.norm.pdf( observation, loc=z0 + B, scale=np.sqrt(T * sigma_sq)) post_z0 = (p_z0 * p_zn_given_z0) / p_zn # plot q q_z0 = scipy.stats.norm.pdf(z0, loc=q0_loc, scale=q0_scale) ax = plt.subplot(gs[0:bar_rows, :]) ax.plot(z0, q_z0, color="blue") ax.plot(z0, post_z0, color="green") ax.plot(z0, p_z0, color="red") ax.legend(("q", "posterior", "prior"), loc="best", prop={"size": 10}) ax.set_xticks([]) ax.set_xlim(*lims) # plot the states for t in range(num_timesteps): start = (t + 1) * (bar_rows + 1) ax1 = plt.subplot(gs[start:start + bar_rows, :]) ax2 = plt.subplot(gs[start + bar_rows:start + bar_rows + 1, :]) # plot the states barplot # ax1.hist( # states[t, :], # weights=weights[t, :], # bins=50, # edgecolor="none", # alpha=0.2) ax1.bar(states[t,:], weights[t,:], width=0.02, alpha=0.2, edgecolor = "none") ax1.set_ylabel("t=%d" % t) ax1.set_xticks([]) ax1.grid(True, which="both") ax1.set_xlim(*lims) # plot the observation ax1.axvline(x=observation, color="red", linestyle="dashed") # add the ESS ax1.text(0.1, 0.9, "ESS: %0.2f" % ess[t], ha='center', va='center', transform=ax1.transAxes) # plot the state location histogram ax2.hist2d( states[t, :], np.zeros_like(states[t, :]), bins=[50, 1], cmap="Greys") ax2.grid(False) ax2.set_yticks([]) ax2.set_xlim(*lims) if t != num_timesteps - 1: ax2.set_xticks([]) fig.canvas.draw() p = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") plots.append(p.reshape(fig.canvas.get_width_height()[::-1] + (3,))) plt.close(fig) return np.stack(plots) plots = tf.py_func(_plot_states, [states, weights, observation, ess, (q0_loc, q0_scale), post], [tf.uint8])[0] tf.summary.image("states", plots, 5, collections=["infrequent_summaries"]) def plot_weights(weights, resampled=None): """Plots the weights and effective sample size from an SMC rollout. Args: weights: [num_timesteps, num_samples, batch_size] importance weights resampled: [num_timesteps] 0/1 indicating if resampling ocurred """ weights = tf.convert_to_tensor(weights) def _make_plots(weights, resampled): num_timesteps, num_samples, batch_size = weights.shape plots = [] for i in range(batch_size): fig, axes = plt.subplots(nrows=1, sharex=True, figsize=(8, 4)) axes.stackplot(np.arange(num_timesteps), np.transpose(weights[:, :, i])) axes.set_title("Weights") axes.set_xlabel("Steps") axes.set_ylim([0, 1]) axes.set_xlim([0, num_timesteps - 1]) for j in np.where(resampled > 0)[0]: axes.axvline(x=j, color="red", linestyle="dashed", ymin=0.0, ymax=1.0) fig.canvas.draw() data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) plots.append(data) plt.close(fig) return np.stack(plots, axis=0) if resampled is None: num_timesteps, _, batch_size = weights.get_shape().as_list() resampled = tf.zeros([num_timesteps], dtype=tf.float32) plots = tf.py_func(_make_plots, [tf.nn.softmax(weights, dim=1), tf.to_float(resampled)], [tf.uint8])[0] batch_size = weights.get_shape().as_list()[-1] tf.summary.image( "weights", plots, batch_size, collections=["infrequent_summaries"]) def summarize_weights(weights, num_timesteps, num_samples): # weights is [num_timesteps, num_samples, batch_size] weights = tf.convert_to_tensor(weights) mean = tf.reduce_mean(weights, axis=1, keepdims=True) squared_diff = tf.square(weights - mean) variances = tf.reduce_sum(squared_diff, axis=1) / (num_samples - 1) # average the variance over the batch variances = tf.reduce_mean(variances, axis=1) avg_magnitude = tf.reduce_mean(tf.abs(weights), axis=[1, 2]) for t in xrange(num_timesteps): tf.summary.scalar("weights/variance_%d" % t, variances[t]) tf.summary.scalar("weights/magnitude_%d" % t, avg_magnitude[t]) tf.summary.histogram("weights/step_%d" % t, weights[t]) def summarize_learning_signal(rewards, tag): num_resampling_events, _ = rewards.get_shape().as_list() mean = tf.reduce_mean(rewards, axis=1) avg_magnitude = tf.reduce_mean(tf.abs(rewards), axis=1) reward_square = tf.reduce_mean(tf.square(rewards), axis=1) for t in xrange(num_resampling_events): tf.summary.scalar("%s/mean_%d" % (tag, t), mean[t]) tf.summary.scalar("%s/magnitude_%d" % (tag, t), avg_magnitude[t]) tf.summary.scalar("%s/squared_%d" % (tag, t), reward_square[t]) tf.summary.histogram("%s/step_%d" % (tag, t), rewards[t]) def summarize_qs(model, observation, states): model.q.summarize_weights() if hasattr(model.p, "posterior") and callable(getattr(model.p, "posterior")): states = [tf.zeros_like(states[0])] + states[:-1] for t, prev_state in enumerate(states): p = model.p.posterior(observation, prev_state, t) q = model.q.q_zt(observation, prev_state, t) kl = tf.reduce_mean(tf.contrib.distributions.kl_divergence(p, q)) tf.summary.scalar("kl_q/%d" % t, tf.reduce_mean(kl)) mean_diff = q.loc - p.loc mean_abs_err = tf.abs(mean_diff) mean_rel_err = tf.abs(mean_diff / p.loc) tf.summary.scalar("q_mean_convergence/absolute_error_%d" % t, tf.reduce_mean(mean_abs_err)) tf.summary.scalar("q_mean_convergence/relative_error_%d" % t, tf.reduce_mean(mean_rel_err)) sigma_diff = tf.square(q.scale) - tf.square(p.scale) sigma_abs_err = tf.abs(sigma_diff) sigma_rel_err = tf.abs(sigma_diff / tf.square(p.scale)) tf.summary.scalar("q_variance_convergence/absolute_error_%d" % t, tf.reduce_mean(sigma_abs_err)) tf.summary.scalar("q_variance_convergence/relative_error_%d" % t, tf.reduce_mean(sigma_rel_err)) def summarize_rs(model, states): model.r.summarize_weights() for t, state in enumerate(states): true_r = model.p.lookahead(state, t) r = model.r.r_xn(state, t) kl = tf.reduce_mean(tf.contrib.distributions.kl_divergence(true_r, r)) tf.summary.scalar("kl_r/%d" % t, tf.reduce_mean(kl)) mean_diff = true_r.loc - r.loc mean_abs_err = tf.abs(mean_diff) mean_rel_err = tf.abs(mean_diff / true_r.loc) tf.summary.scalar("r_mean_convergence/absolute_error_%d" % t, tf.reduce_mean(mean_abs_err)) tf.summary.scalar("r_mean_convergence/relative_error_%d" % t, tf.reduce_mean(mean_rel_err)) sigma_diff = tf.square(r.scale) - tf.square(true_r.scale) sigma_abs_err = tf.abs(sigma_diff) sigma_rel_err = tf.abs(sigma_diff / tf.square(true_r.scale)) tf.summary.scalar("r_variance_convergence/absolute_error_%d" % t, tf.reduce_mean(sigma_abs_err)) tf.summary.scalar("r_variance_convergence/relative_error_%d" % t, tf.reduce_mean(sigma_rel_err)) def summarize_model(model, true_bs, observation, states, bound, summarize_r=True): if hasattr(model.p, "bs"): model_b = tf.reduce_sum(model.p.bs, axis=0) true_b = tf.reduce_sum(true_bs, axis=0) abs_err = tf.abs(model_b - true_b) rel_err = abs_err / true_b tf.summary.scalar("sum_of_bs/data_generating_process", tf.reduce_mean(true_b)) tf.summary.scalar("sum_of_bs/model", tf.reduce_mean(model_b)) tf.summary.scalar("sum_of_bs/absolute_error", tf.reduce_mean(abs_err)) tf.summary.scalar("sum_of_bs/relative_error", tf.reduce_mean(rel_err)) #summarize_qs(model, observation, states) #if bound == "fivo-aux" and summarize_r: # summarize_rs(model, states) def summarize_grads(grads, loss_name): grad_ema = tf.train.ExponentialMovingAverage(decay=0.99) vectorized_grads = tf.concat( [tf.reshape(g, [-1]) for g, _ in grads if g is not None], axis=0) new_second_moments = tf.square(vectorized_grads) new_first_moments = vectorized_grads maintain_grad_ema_op = grad_ema.apply([new_first_moments, new_second_moments]) first_moments = grad_ema.average(new_first_moments) second_moments = grad_ema.average(new_second_moments) variances = second_moments - tf.square(first_moments) tf.summary.scalar("grad_variance/%s" % loss_name, tf.reduce_mean(variances)) tf.summary.histogram("grad_variance/%s" % loss_name, variances) tf.summary.histogram("grad_mean/%s" % loss_name, first_moments) return maintain_grad_ema_op