|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import matplotlib |
|
matplotlib.use('Agg') |
|
from matplotlib import pyplot as plt |
|
import numpy as np |
|
import tensorflow as tf |
|
|
|
def _plot_item(W, name, full_name, nspaces): |
|
plt.figure() |
|
if W.shape == (): |
|
print(name, ": ", W) |
|
elif W.shape[0] == 1: |
|
plt.stem(W.T) |
|
plt.title(full_name) |
|
elif W.shape[1] == 1: |
|
plt.stem(W) |
|
plt.title(full_name) |
|
else: |
|
plt.imshow(np.abs(W), interpolation='nearest', cmap='jet'); |
|
plt.colorbar() |
|
plt.title(full_name) |
|
|
|
|
|
def all_plot(d, full_name="", exclude="", nspaces=0):
  """Recursively plot all the LFADS model parameters in a nested dictionary.

  Each leaf (non-dict value) is handed to _plot_item, titled with its
  '/'-joined key path.

  Args:
    d: dict mapping string keys to parameter values or to nested dicts.
    full_name: path prefix for the keys at this level.
    exclude: if non-empty, leaves whose full path contains this substring
      are skipped.
    nspaces: nesting depth bookkeeping; grows by 4 per level.
  """
  # .items() works on both Python 2 and 3 (dict.iteritems is Python-2-only).
  for k, v in d.items():
    this_name = full_name + "/" + k
    if isinstance(v, dict):
      all_plot(v, full_name=this_name, exclude=exclude, nspaces=nspaces + 4)
    # An empty `exclude` would match every path ("" is in any string), so it
    # must mean "exclude nothing".
    elif not exclude or exclude not in this_name:
      _plot_item(v, name=k, full_name=this_name, nspaces=nspaces + 4)
|
|
|
|
|
|
|
def plot_time_series(vals_bxtxn, bidx=None, n_to_plot=np.inf, scale=1.0,
                     color='r', title=None):
  """Draw channels of a [batch, time, channel] array as stacked traces.

  Channel i is shifted upward by `scale * i` so traces do not overlap.
  Everything is drawn on the current matplotlib axes.

  Args:
    vals_bxtxn: array indexed as [batch, time, channel].
    bidx: batch index to draw; when None, the mean over the batch is drawn.
    n_to_plot: maximum number of channels to draw; clipped to the number of
      channels available.
    scale: vertical offset between consecutive traces.
    color: matplotlib color used for every trace.
    title: optional axes title; only set when truthy.
  """
  vals_txn = (np.mean(vals_bxtxn, axis=0) if bidx is None
              else vals_bxtxn[bidx, :, :])

  _, n_channels = vals_txn.shape
  n_to_plot = min(n_to_plot, n_channels)

  offsets = scale * np.arange(n_to_plot)
  plt.plot(vals_txn[:, :n_to_plot] + offsets, color=color, lw=1.0)
  plt.axis('tight')
  if title:
    plt.title(title)
|
|
|
|
|
def plot_lfads_timeseries(data_bxtxn, model_vals, ext_input_bxtxi=None,
                          truth_bxtxn=None, bidx=None, output_dist="poisson",
                          conversion_factor=1.0, subplot_cidx=0,
                          col_title=None):
  """Plot one column of LFADS diagnostics into a 7-row x 2-column grid.

  Panels (rows) drawn into column `subplot_cidx` of the current figure:
    1. inferred rates (poisson) or means +/- one std (gaussian), plus the
       ground truth in black when `truth_bxtxn` is given;
    2. controller outputs and/or external inputs, when present;
    3. inferred means overlaid with the data;
    4. model_vals['factors'];
    5. model_vals['gen_states'];
    6. heat map of the data (neurons x time);
    7. heat map of the inferred means (neurons x time).

  Args:
    data_bxtxn: observed data indexed [batch, time, neuron].
    model_vals: dict with 'output_dist_params', 'factors', 'gen_states' and
      optionally 'controller_outputs'.
    ext_input_bxtxi: optional external inputs indexed [batch, time, input].
    truth_bxtxn: optional ground-truth rates to overlay in the first row.
    bidx: batch index to plot; when None, batch averages are plotted.
    output_dist: 'poisson' or 'gaussian'. For 'gaussian', the last axis of
      'output_dist_params' is split in half into means and variances.
    conversion_factor: multiplier applied to the poisson rates in row 1.
    subplot_cidx: grid column to draw into (0 or 1).
    col_title: prefix for every panel title; None is treated as "".

  Raises:
    NotImplementedError: if output_dist is not 'poisson' or 'gaussian'.
  """
  # Fail fast. The original `assert 'NIY'` was always true and let unsupported
  # distributions fall through to an unbound-variable error later.
  if output_dist not in ('poisson', 'gaussian'):
    raise NotImplementedError(
        "output_dist %r is not implemented yet" % (output_dist,))
  if col_title is None:
    col_title = ""

  n_to_plot = 10
  scale = 1.0
  nrows = 7

  # Row 1: inferred rates / means, optionally with ground truth.
  plt.subplot(nrows, 2, 1 + subplot_cidx)
  if output_dist == 'poisson':
    means = conversion_factor * model_vals['output_dist_params']
    plot_time_series(means, bidx, n_to_plot=n_to_plot, scale=scale,
                     title=col_title + " rates (LFADS - red, Truth - black)")
  else:  # 'gaussian'
    means_vars = model_vals['output_dist_params']
    means, variances = np.split(means_vars, 2, axis=2)
    stds = np.sqrt(variances)
    plot_time_series(means, bidx, n_to_plot=n_to_plot, scale=scale,
                     title=col_title + " means (LFADS - red, Truth - black)")
    plot_time_series(means + stds, bidx, n_to_plot=n_to_plot, scale=scale,
                     color='c')
    plot_time_series(means - stds, bidx, n_to_plot=n_to_plot, scale=scale,
                     color='c')

  if truth_bxtxn is not None:
    plot_time_series(truth_bxtxn, bidx, n_to_plot=n_to_plot, color='k',
                     scale=scale)

  # Row 2: controller outputs and external inputs share one panel. Select it
  # once whenever either is present, so external inputs are never drawn onto
  # the rates panel.
  has_controller = "controller_outputs" in model_vals
  if has_controller or ext_input_bxtxi is not None:
    plt.subplot(nrows, 2, 3 + subplot_cidx)

  input_title = ""
  if has_controller:
    input_title += " Controller Output"
    # NOTE(review): [0:-1] slices the first axis, which plot_time_series
    # treats as batch; confirm that dropping the last entry is intended.
    u_t = model_vals['controller_outputs'][0:-1]
    plot_time_series(u_t, bidx, n_to_plot=n_to_plot, color='c', scale=1.0,
                     title=col_title + input_title)

  if ext_input_bxtxi is not None:
    input_title += " External Input"
    # Use the same batch selection as every other panel.
    plot_time_series(ext_input_bxtxi, bidx, n_to_plot=n_to_plot, color='b',
                     scale=scale, title=col_title + input_title)

  # Row 3: inferred means vs. observed data.
  plt.subplot(nrows, 2, 5 + subplot_cidx)
  plot_time_series(means, bidx,
                   n_to_plot=n_to_plot, scale=1.0,
                   title=col_title + " Spikes (LFADS - red, Spikes - black)")
  plot_time_series(data_bxtxn, bidx, n_to_plot=n_to_plot, color='k', scale=1.0)

  # Row 4: factors.
  plt.subplot(nrows, 2, 7 + subplot_cidx)
  plot_time_series(model_vals['factors'], bidx, n_to_plot=n_to_plot, color='b',
                   scale=2.0, title=col_title + " Factors")

  # Row 5: generator states.
  plt.subplot(nrows, 2, 9 + subplot_cidx)
  plot_time_series(model_vals['gen_states'], bidx, n_to_plot=n_to_plot,
                   color='g', scale=1.0, title=col_title + " Generator State")

  # Rows 6-7: heat maps, transposed to [neuron, time].
  if bidx is not None:
    data_nxt = data_bxtxn[bidx, :, :].T
    params_nxt = model_vals['output_dist_params'][bidx, :, :].T
  else:
    data_nxt = np.mean(data_bxtxn, axis=0).T
    params_nxt = np.mean(model_vals['output_dist_params'], axis=0).T

  if output_dist == 'poisson':
    means_nxt = params_nxt
  else:  # 'gaussian'
    # Means and variances were concatenated on the last axis, which is the
    # first axis after the transpose; keep the first (means) half.
    means_nxt = np.vsplit(params_nxt, 2)[0]

  plt.subplot(nrows, 2, 11 + subplot_cidx)
  plt.imshow(data_nxt, aspect='auto', interpolation='nearest')
  plt.title(col_title + ' Data')

  plt.subplot(nrows, 2, 13 + subplot_cidx)
  plt.imshow(means_nxt, aspect='auto', interpolation='nearest')
  plt.title(col_title + ' Means')
|
|
|
|
|
def plot_lfads(train_bxtxd, train_model_vals,
               train_ext_input_bxtxi=None, train_truth_bxtxd=None,
               valid_bxtxd=None, valid_model_vals=None,
               valid_ext_input_bxtxi=None, valid_truth_bxtxd=None,
               bidx=None, cf=1.0, output_dist='poisson'):
  """Render train (left) and validation (right) LFADS diagnostics to an image.

  A new 18x20-inch figure is created. The train column is always drawn. The
  validation column is drawn only when both `valid_bxtxd` and
  `valid_model_vals` are given. The figure is rasterized and closed.

  Args:
    train_bxtxd, train_model_vals, train_ext_input_bxtxi, train_truth_bxtxd:
      train-set inputs forwarded to plot_lfads_timeseries.
    valid_bxtxd, valid_model_vals, valid_ext_input_bxtxi, valid_truth_bxtxd:
      validation-set counterparts (optional).
    bidx: batch index to plot, or None for batch averages.
    cf: conversion factor applied to poisson rates.
    output_dist: 'poisson' or 'gaussian'.

  Returns:
    uint8 array of shape (height, width, 3) holding the RGB pixels of the
    rendered figure.
  """
  f = plt.figure(figsize=(18, 20), tight_layout=True)
  plot_lfads_timeseries(train_bxtxd, train_model_vals,
                        train_ext_input_bxtxi,
                        truth_bxtxn=train_truth_bxtxd,
                        conversion_factor=cf, bidx=bidx,
                        output_dist=output_dist, col_title='Train')
  # The validation column is optional (its arguments default to None), so
  # skip it rather than crash on None inputs.
  if valid_bxtxd is not None and valid_model_vals is not None:
    plot_lfads_timeseries(valid_bxtxd, valid_model_vals,
                          valid_ext_input_bxtxi,
                          truth_bxtxn=valid_truth_bxtxd,
                          conversion_factor=cf, bidx=bidx,
                          output_dist=output_dist,
                          subplot_cidx=1, col_title='Valid')

  f.canvas.draw()
  width, height = f.canvas.get_width_height()
  if hasattr(f.canvas, 'tostring_rgb'):
    # np.frombuffer replaces the deprecated binary mode of np.fromstring.
    # copy() keeps the returned array writable, as before.
    data = np.frombuffer(f.canvas.tostring_rgb(), dtype=np.uint8)
    data_hxwx3 = data.reshape(height, width, 3).copy()
  else:
    # tostring_rgb was removed in newer matplotlib; drop alpha from RGBA.
    rgba = np.asarray(f.canvas.buffer_rgba())
    data_hxwx3 = rgba.reshape(height, width, 4)[:, :, :3].copy()
  # Close this specific figure, not whichever figure happens to be current.
  plt.close(f)

  return data_hxwx3
|
|