""" |
LFADS - Latent Factor Analysis via Dynamical Systems. |
LFADS is an unsupervised method to decompose time series data into |
various factors, such as an initial condition, a generative |
dynamical system, control inputs to that generator, and a low |
dimensional description of the observed data, called the factors. |
Additionally, the observations have a noise model (in this case |
Poisson), so a denoised version of the observations is also created |
(e.g. underlying rates of a Poisson distribution given the observed |
event counts). |
The main data structure being passed around is a dataset. This is a dictionary |
of data dictionaries. |
DATASET: The top level dictionary is simply name (string -> dictionary). |
The nested dictionary is the DATA DICTIONARY, which has the following keys: |
'train_data' and 'valid_data', whose values are the corresponding training |
and validation data with shape |
ExTxD, E - # examples, T - # time steps, D - # dimensions in data. |
The data dictionary also has a few more keys: |
'train_ext_input' and 'valid_ext_input', if there are known external inputs
to the system being modeled, these take on dimensions: |
ExTxI, E - # examples, T - # time steps, I = # dimensions in input. |
'alignment_matrix_cxf' - If you are using multiple days data, it's possible |
that one can align the channels (see manuscript). If so each dataset will |
contain this matrix, which will be used for both the input adapter and the |
output adapter for each dataset. These matrices, if provided, must be of |
size [data_dim x factors] where data_dim is the number of neurons recorded |
on that day, and factors is chosen and set through the '--factors' flag. |
'alignment_bias_c' - See alignment_matrix_cxf. This bias will be used as
the offset for the alignment transformation. It will *subtract* off the
bias from the data, so pca style inits can align factors across sessions. |
If one runs LFADS on data where the true rates are known for some trials, |
(say simulated, testing data, as in the example shipped with the paper), then |
one can add three more fields for plotting purposes. These are 'train_truth' |
and 'valid_truth', and 'conversion_factor'. These have the same dimensions as |
'train_data', and 'valid_data' but represent the underlying rates of the |
observations. Finally, if one needs to convert scale for plotting the true |
underlying firing rates, there is the 'conversion_factor' key. |
""" |
from __future__ import absolute_import |
from __future__ import division |
from __future__ import print_function |
import numpy as np |
import os |
import tensorflow as tf |
from distributions import LearnableDiagonalGaussian, DiagonalGaussianFromInput |
from distributions import diag_gaussian_log_likelihood |
from distributions import KLCost_GaussianGaussian, Poisson |
from distributions import LearnableAutoRegressive1Prior |
from distributions import KLCost_GaussianGaussianProcessSampled |
from utils import init_linear, linear, list_t_bxn_to_tensor_bxtxn, write_data |
from utils import log_sum_exp, flatten |
from plot_lfads import plot_lfads |
class GRU(object):
  """Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078).

  The gates and the candidate are each computed by a single `linear` call on
  the concatenation of the input and the (possibly reset-gated) state.
  """

  def __init__(self, num_units, forget_bias=1.0, weight_scale=1.0,
               clip_value=np.inf, collections=None):
    """Create a GRU object.

    Args:
      num_units: Number of units in the GRU.
      forget_bias (optional): Constant added to the update-gate
        pre-activation before the sigmoid. Hack to help learning.
      weight_scale (optional): Weights are scaled by ws/sqrt(#inputs), with
        ws being the weight scale.
      clip_value (optional): If the recurrent values grow above this value,
        clip them.
      collections (optional): List of additional collections variables should
        belong to.
    """
    self._num_units = num_units
    self._forget_bias = forget_bias
    self._weight_scale = weight_scale
    self._clip_value = clip_value
    self._collections = collections

  @property
  def state_size(self):
    """Size of the recurrent state (equal to num_units)."""
    return self._num_units

  @property
  def output_size(self):
    """Size of the cell output (equal to num_units)."""
    return self._num_units

  @property
  def state_multiplier(self):
    """Always 1 for this cell."""
    return 1

  def output_from_state(self, state):
    """Return the output portion of the state."""
    return state

  def __call__(self, inputs, state, scope=None):
    """Gated recurrent unit (GRU) function.

    Args:
      inputs: A 2D batch x input_dim tensor of inputs, or None for a cell
        that is driven by its recurrent state only.
      state: The previous state from the last time step.
      scope (optional): TF variable scope for defined GRU variables.

    Returns:
      A tuple (state, state), where state is the newly computed state at time
      t. It is returned twice to respect an interface that works for LSTMs.
    """
    x = inputs
    h = state
    # With no input, the gates are driven by the recurrent state only.
    xh = h if x is None else tf.concat(axis=1, values=[x, h])

    with tf.variable_scope(scope or type(self).__name__):
      with tf.variable_scope("Gates"):
        # One linear map produces both gate pre-activations; split in half.
        r, u = tf.split(axis=1, num_or_size_splits=2, value=linear(
            xh,
            2 * self._num_units,
            alpha=self._weight_scale,
            name="xh_2_ru",
            collections=self._collections))
        r, u = tf.sigmoid(r), tf.sigmoid(u + self._forget_bias)
      with tf.variable_scope("Candidate"):
        rh = r * h
        # Mirror the gate path: skip the concat when there is no input,
        # since tf.concat cannot take a None entry.
        xrh = rh if x is None else tf.concat(axis=1, values=[x, rh])
        # NOTE(review): unlike the gates, the candidate map does not pass
        # alpha=self._weight_scale; kept as in the original.
        c = tf.tanh(linear(xrh, self._num_units, name="xrh_2_c",
                           collections=self._collections))
      new_h = u * h + (1 - u) * c
      new_h = tf.clip_by_value(new_h, -self._clip_value, self._clip_value)

    return new_h, new_h
class GenGRU(object):
  """Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078).

  Generator-specialized variant of the GRU. Input->recurrent and
  recurrent->recurrent weights are separate `linear` maps, each with its own
  weight scale and variable collections, which allows l2 regularization on
  the recurrent weights alone. Per the original author's note it is not as
  fast as the plain GRU, and the 1/sqrt(#inputs) scaling in the linear helper
  implicitly rescales the inputs to large magnitude when there are fewer
  inputs than recurrent units.
  """

  def __init__(self, num_units, forget_bias=1.0,
               input_weight_scale=1.0, rec_weight_scale=1.0, clip_value=np.inf,
               input_collections=None, recurrent_collections=None):
    """Create a GenGRU object.

    Args:
      num_units: Number of units in the GRU.
      forget_bias (optional): Constant added to the update-gate
        pre-activation before the sigmoid. Hack to help learning.
      input_weight_scale (optional): Input weights are scaled
        ws/sqrt(#inputs), with ws being the weight scale.
      rec_weight_scale (optional): Recurrent weights are scaled
        ws/sqrt(#inputs), with ws being the weight scale.
      clip_value (optional): If the recurrent values grow above this value,
        clip them.
      input_collections (optional): List of additional collections that the
        input->rec weight variables should belong to.
      recurrent_collections (optional): List of additional collections that
        the rec->rec weight variables should belong to.
    """
    self._num_units = num_units
    self._forget_bias = forget_bias
    self._input_weight_scale = input_weight_scale
    self._rec_weight_scale = rec_weight_scale
    self._clip_value = clip_value
    self._input_collections = input_collections
    self._rec_collections = recurrent_collections

  @property
  def state_size(self):
    """Size of the recurrent state (equal to num_units)."""
    return self._num_units

  @property
  def output_size(self):
    """Size of the cell output (equal to num_units)."""
    return self._num_units

  @property
  def state_multiplier(self):
    """Always 1 for this cell."""
    return 1

  def output_from_state(self, state):
    """Return the output portion of the state."""
    return state

  def __call__(self, inputs, state, scope=None):
    """Gated recurrent unit (GRU) function.

    Args:
      inputs: A 2D batch x input_dim tensor of inputs, or None when the cell
        receives no input.
      state: The previous state from the last time step.
      scope (optional): TF variable scope for defined GRU variables.

    Returns:
      A tuple (state, state), where state is the newly computed state at time
      t. It is returned twice to respect an interface that works for LSTMs.
    """
    n_units = self._num_units
    has_input = inputs is not None

    with tf.variable_scope(scope or type(self).__name__):
      with tf.variable_scope("Gates"):
        # Input contribution to the reset/update gates; a plain 0.0 when the
        # cell has no input so the sums below still work.
        reset_in = update_in = 0.0
        if has_input:
          in_2_ru = linear(inputs,
                           2 * n_units,
                           alpha=self._input_weight_scale,
                           do_bias=False,
                           name="x_2_ru",
                           normalized=False,
                           collections=self._input_collections)
          reset_in, update_in = tf.split(
              axis=1, num_or_size_splits=2, value=in_2_ru)

        # Recurrent contribution; this map carries the gate biases.
        rec_2_ru = linear(state,
                          2 * n_units,
                          do_bias=True,
                          alpha=self._rec_weight_scale,
                          name="h_2_ru",
                          collections=self._rec_collections)
        reset_rec, update_rec = tf.split(
            axis=1, num_or_size_splits=2, value=rec_2_ru)

        reset_pre = reset_in + reset_rec
        update_pre = update_in + update_rec
        reset_gate = tf.sigmoid(reset_pre)
        update_gate = tf.sigmoid(update_pre + self._forget_bias)

      with tf.variable_scope("Candidate"):
        cand_in = 0.0
        if has_input:
          cand_in = linear(inputs, n_units, name="x_2_c", do_bias=False,
                           alpha=self._input_weight_scale,
                           normalized=False,
                           collections=self._input_collections)
        cand_rec = linear(reset_gate * state, n_units, name="rh_2_c",
                          do_bias=True,
                          alpha=self._rec_weight_scale,
                          collections=self._rec_collections)
        candidate = tf.tanh(cand_in + cand_rec)

      next_state = update_gate * state + (1 - update_gate) * candidate
      next_state = tf.clip_by_value(next_state,
                                    -self._clip_value, self._clip_value)

    return next_state, next_state
class LFADS(object): |
"""LFADS - Latent Factor Analysis via Dynamical Systems. |
LFADS is an unsupervised method to decompose time series data into |
various factors, such as an initial condition, a generative |
dynamical system, inferred inputs to that generator, and a low |
dimensional description of the observed data, called the factors. |
Additionally, the observations have a noise model (in this case
Poisson), so a denoised version of the observations is also created |
(e.g. underlying rates of a Poisson distribution given the observed |
event counts). |
""" |
def __init__(self, hps, kind="train", datasets=None):
  """Create an LFADS model.

     train - a model for training, sampling of posteriors is used
     posterior_sample_and_average - sample from the posterior, this is used
       for evaluating the expected value of the outputs of LFADS, given a
       specific input, by averaging over multiple samples from the approx
       posterior.  Also used for the lower bound on the negative
       log-likelihood using IWAE error (Importance Weighed Auto-encoder).
       This is the denoising operation.
     posterior_push_mean - use the posterior means (no sampling) for the
       initial condition and the controller outputs.
     prior_sample - a model for generation - sampling from priors is used

  Args:
    hps: The hyper parameters, read by attribute (e.g. hps.num_steps).
    kind: the type of model to build (see above).
    datasets: a dictionary of named data_dictionaries, see top of lfads.py

  Raises:
    ValueError: if a provided alignment matrix or alignment bias does not
      match the dataset's data dimension (and factors dimension).
  """
  print("Building graph...")
  all_kinds = ['train', 'posterior_sample_and_average', 'posterior_push_mean',
               'prior_sample']
  assert kind in all_kinds, 'Wrong kind'
  if hps.feedback_factors_or_rates == "rates":
    assert len(hps.dataset_names) == 1, \
        "Multiple datasets not supported for rate feedback."
  num_steps = hps.num_steps
  ic_dim = hps.ic_dim
  co_dim = hps.co_dim
  ext_input_dim = hps.ext_input_dim
  cell_class = GRU
  gen_cell_class = GenGRU

  def makelambda(v):
    # tf.case wants callables; bind v now to avoid late-binding closures.
    return lambda: v

  # Name of the dataset being fed; selects the read-in/read-out matrices.
  self.dataName = tf.placeholder(tf.string, shape=())
  if hps.output_dist == 'poisson':
    assert np.issubdtype(
        datasets[hps.dataset_names[0]]['train_data'].dtype, int), \
        "Data dtype must be int for poisson output distribution"
    data_dtype = tf.int32
  elif hps.output_dist == 'gaussian':
    assert np.issubdtype(
        datasets[hps.dataset_names[0]]['train_data'].dtype, float), \
        "Data dtype must be float for gaussian output dsitribution"
    data_dtype = tf.float32
  else:
    assert False, "NIY"
  # Batch size and data dimension are left unspecified so that a single
  # placeholder serves every dataset.
  self.dataset_ph = dataset_ph = tf.placeholder(data_dtype,
                                                [None, num_steps, None],
                                                name="data")
  self.train_step = tf.get_variable("global_step", [], tf.int64,
                                    tf.zeros_initializer(),
                                    trainable=False)
  self.hps = hps
  ndatasets = hps.ndatasets
  factors_dim = hps.factors_dim
  self.preds = preds = [None] * ndatasets
  self.fns_in_fac_Ws = fns_in_fac_Ws = [None] * ndatasets
  # 'fns_in_fatcor_bs' is a historical misspelling kept for existing callers;
  # 'fns_in_fac_bs' is the same list under the intended name.
  self.fns_in_fatcor_bs = fns_in_fac_bs = [None] * ndatasets
  self.fns_in_fac_bs = fns_in_fac_bs
  self.fns_out_fac_Ws = fns_out_fac_Ws = [None] * ndatasets
  self.fns_out_fac_bs = fns_out_fac_bs = [None] * ndatasets
  self.datasetNames = dataset_names = hps.dataset_names
  self.ext_inputs = None

  # READ-IN (data -> input factors), one linear map per dataset.
  if len(dataset_names) == 1:  # single session
    if 'alignment_matrix_cxf' in datasets[dataset_names[0]].keys():
      used_in_factors_dim = factors_dim
      in_identity_if_poss = False
    else:
      used_in_factors_dim = hps.dataset_dims[dataset_names[0]]
      in_identity_if_poss = True
  else:  # multi session
    used_in_factors_dim = factors_dim
    in_identity_if_poss = False

  for d, name in enumerate(dataset_names):
    data_dim = hps.dataset_dims[name]
    in_mat_cxf = None
    in_bias_1xf = None
    align_bias_1xc = None

    if datasets and 'alignment_matrix_cxf' in datasets[name].keys():
      dataset = datasets[name]
      if hps.do_train_readin:
        print("Initializing trainable readin matrix with alignment matrix"
              " provided for dataset:", name)
      else:
        print("Setting non-trainable readin matrix to alignment matrix"
              " provided for dataset:", name)
      in_mat_cxf = dataset['alignment_matrix_cxf'].astype(np.float32)
      if in_mat_cxf.shape != (data_dim, factors_dim):
        raise ValueError(
            "Alignment matrix must have dimensions %d x %d "
            "(data_dim x factors_dim), but currently has %d x %d." %
            (data_dim, factors_dim, in_mat_cxf.shape[0],
             in_mat_cxf.shape[1]))
    if datasets and 'alignment_bias_c' in datasets[name].keys():
      dataset = datasets[name]
      if hps.do_train_readin:
        print("Initializing trainable readin bias with alignment bias "
              "provided for dataset:", name)
      else:
        print("Setting non-trainable readin bias to alignment bias "
              "provided for dataset:", name)
      align_bias_c = dataset['alignment_bias_c'].astype(np.float32)
      align_bias_1xc = np.expand_dims(align_bias_c, axis=0)
      if align_bias_1xc.shape[1] != data_dim:
        # Report the bias's own size; the alignment matrix may not exist.
        raise ValueError(
            "Alignment bias must have dimensions %d "
            "(data_dim), but currently has %d." %
            (data_dim, align_bias_1xc.shape[1]))
    if in_mat_cxf is not None and align_bias_1xc is not None:
      # (data - bias) * W_in == data * W_in - bias * W_in, so the read-in
      # bias is -bias * W_in.
      in_bias_1xf = -np.dot(align_bias_1xc, in_mat_cxf)

    # Only a trainable read-in joins the IO_transformations collection.
    if hps.do_train_readin:
      collections_readin = ['IO_transformations']
    else:
      collections_readin = None

    in_fac_lin = init_linear(data_dim, used_in_factors_dim,
                             do_bias=True,
                             mat_init_value=in_mat_cxf,
                             bias_init_value=in_bias_1xf,
                             identity_if_possible=in_identity_if_poss,
                             normalized=False, name="x_2_infac_"+name,
                             collections=collections_readin,
                             trainable=hps.do_train_readin)
    in_fac_W, in_fac_b = in_fac_lin
    fns_in_fac_Ws[d] = makelambda(in_fac_W)
    fns_in_fac_bs[d] = makelambda(in_fac_b)

  # READ-OUT (factors -> output distribution parameters), one per dataset.
  with tf.variable_scope("glm"):
    out_identity_if_poss = False
    if len(dataset_names) == 1 and \
        factors_dim == hps.dataset_dims[dataset_names[0]]:
      out_identity_if_poss = True
    for d, name in enumerate(dataset_names):
      data_dim = hps.dataset_dims[name]
      in_mat_cxf = None
      # Reset per dataset so a dataset without 'alignment_bias_c' does not
      # inherit the bias of a previously processed dataset.
      align_bias_1xc = None
      if datasets and 'alignment_matrix_cxf' in datasets[name].keys():
        dataset = datasets[name]
        in_mat_cxf = dataset['alignment_matrix_cxf'].astype(np.float32)
      if datasets and 'alignment_bias_c' in datasets[name].keys():
        dataset = datasets[name]
        align_bias_c = dataset['alignment_bias_c'].astype(np.float32)
        align_bias_1xc = np.expand_dims(align_bias_c, axis=0)

      out_mat_fxc = None
      out_bias_1xc = None
      if in_mat_cxf is not None:
        out_mat_fxc = in_mat_cxf.T
      if align_bias_1xc is not None:
        out_bias_1xc = align_bias_1xc

      if hps.output_dist == 'poisson':
        out_fac_lin = init_linear(factors_dim, data_dim, do_bias=True,
                                  mat_init_value=out_mat_fxc,
                                  bias_init_value=out_bias_1xc,
                                  identity_if_possible=out_identity_if_poss,
                                  normalized=False,
                                  name="fac_2_logrates_"+name,
                                  collections=['IO_transformations'])
        out_fac_W, out_fac_b = out_fac_lin
      elif hps.output_dist == 'gaussian':
        out_fac_lin_mean = \
            init_linear(factors_dim, data_dim, do_bias=True,
                        mat_init_value=out_mat_fxc,
                        bias_init_value=out_bias_1xc,
                        normalized=False,
                        name="fac_2_means_"+name,
                        collections=['IO_transformations'])
        # Log-variance read-out starts at zero weights and unit bias.
        mat_init_value = np.zeros([factors_dim, data_dim]).astype(np.float32)
        bias_init_value = np.ones([1, data_dim]).astype(np.float32)
        out_fac_lin_logvar = \
            init_linear(factors_dim, data_dim, do_bias=True,
                        mat_init_value=mat_init_value,
                        bias_init_value=bias_init_value,
                        normalized=False,
                        name="fac_2_logvars_"+name,
                        collections=['IO_transformations'])
        out_fac_W_mean, out_fac_b_mean = out_fac_lin_mean
        out_fac_W_logvar, out_fac_b_logvar = out_fac_lin_logvar
        # Means and log-variances share one matmul; split again downstream.
        out_fac_W = tf.concat(
            axis=1, values=[out_fac_W_mean, out_fac_W_logvar])
        out_fac_b = tf.concat(
            axis=1, values=[out_fac_b_mean, out_fac_b_logvar])
      else:
        assert False, "NIY"

      preds[d] = tf.equal(tf.constant(name), self.dataName)
      fns_out_fac_Ws[d] = makelambda(out_fac_W)
      fns_out_fac_bs[d] = makelambda(out_fac_b)

  # tf.case requires a list/tuple/dict; zip() is a one-shot iterator on
  # Python 3, so materialize the pairs.
  pf_pairs_in_fac_Ws = list(zip(preds, fns_in_fac_Ws))
  pf_pairs_in_fac_bs = list(zip(preds, fns_in_fac_bs))
  pf_pairs_out_fac_Ws = list(zip(preds, fns_out_fac_Ws))
  pf_pairs_out_fac_bs = list(zip(preds, fns_out_fac_bs))

  this_in_fac_W = tf.case(pf_pairs_in_fac_Ws, exclusive=True)
  this_in_fac_b = tf.case(pf_pairs_in_fac_bs, exclusive=True)
  this_out_fac_W = tf.case(pf_pairs_out_fac_Ws, exclusive=True)
  this_out_fac_b = tf.case(pf_pairs_out_fac_bs, exclusive=True)

  # External inputs (same placeholder for every dataset).
  if hps.ext_input_dim > 0:
    self.ext_input = tf.placeholder(tf.float32,
                                    [None, num_steps, ext_input_dim],
                                    name="ext_input")
  else:
    self.ext_input = None
  ext_input_bxtxi = self.ext_input

  self.keep_prob = keep_prob = tf.placeholder(tf.float32, [], "keep_prob")
  self.batch_size = batch_size = int(hps.batch_size)
  self.learning_rate = tf.Variable(float(hps.learning_rate_init),
                                   trainable=False, name="learning_rate")
  self.learning_rate_decay_op = self.learning_rate.assign(
      self.learning_rate * hps.learning_rate_decay_factor)

  # Dropout on the data (and external inputs) feeding the encoders.
  dataset_do_bxtxd = tf.nn.dropout(tf.to_float(dataset_ph), keep_prob)
  if hps.ext_input_dim > 0:
    ext_input_do_bxtxi = tf.nn.dropout(ext_input_bxtxi, keep_prob)
  else:
    ext_input_do_bxtxi = None

  # ENCODERS
  def encode_data(dataset_bxtxd, enc_cell, name, forward_or_reverse,
                  num_steps_to_encode):
    """Encode data for LFADS

    Args:
      dataset_bxtxd - the data to encode, as a 3 tensor, with dims
        batch x time x data dims.
      enc_cell: encoder cell
      name: name of encoder
      forward_or_reverse: string, encode in forward or reverse direction
      num_steps_to_encode: number of steps to encode, 0:num_steps_to_encode
    Returns:
      encoded data as a list with num_steps_to_encode items, indexed by
      time step regardless of the encoding direction
    """
    if forward_or_reverse == "forward":
      dstr = "_fwd"
      time_fwd_or_rev = range(num_steps_to_encode)
    else:
      dstr = "_rev"
      time_fwd_or_rev = reversed(range(num_steps_to_encode))

    with tf.variable_scope(name+"_enc"+dstr, reuse=False):
      # Learned initial encoder state, tiled across the batch.
      enc_state = tf.tile(
          tf.Variable(tf.zeros([1, enc_cell.state_size]),
                      name=name+"_enc_t0"+dstr), tf.stack([batch_size, 1]))
      enc_state.set_shape([None, enc_cell.state_size])  # tile loses shape

    enc_outs = [None] * num_steps_to_encode
    for i, t in enumerate(time_fwd_or_rev):
      # Variables are created on the first processed step, reused after.
      with tf.variable_scope(name+"_enc"+dstr, reuse=True if i > 0 else None):
        dataset_t_bxd = dataset_bxtxd[:, t, :]
        in_fac_t_bxf = tf.matmul(dataset_t_bxd, this_in_fac_W) + this_in_fac_b
        in_fac_t_bxf.set_shape([None, used_in_factors_dim])
        if ext_input_dim > 0 and not hps.inject_ext_input_to_gen:
          ext_input_t_bxi = ext_input_do_bxtxi[:, t, :]
          enc_input_t_bxfpe = tf.concat(
              axis=1, values=[in_fac_t_bxf, ext_input_t_bxi])
        else:
          enc_input_t_bxfpe = in_fac_t_bxf
        enc_out, enc_state = enc_cell(enc_input_t_bxfpe, enc_state)
        enc_outs[t] = enc_out

    return enc_outs

  # Initial-condition encoders (forward and reverse passes).
  self.ic_enc_fwd = [None] * num_steps
  self.ic_enc_rev = [None] * num_steps
  if ic_dim > 0:
    enc_ic_cell = cell_class(hps.ic_enc_dim,
                             weight_scale=hps.cell_weight_scale,
                             clip_value=hps.cell_clip_value)
    ic_enc_fwd = encode_data(dataset_do_bxtxd, enc_ic_cell,
                             "ic", "forward",
                             hps.num_steps_for_gen_ic)
    ic_enc_rev = encode_data(dataset_do_bxtxd, enc_ic_cell,
                             "ic", "reverse",
                             hps.num_steps_for_gen_ic)
    self.ic_enc_fwd = ic_enc_fwd
    self.ic_enc_rev = ic_enc_rev

  # Controller-input encoders; the reverse pass is skipped when the
  # controller is causal.
  self.ci_enc_fwd = [None] * num_steps
  self.ci_enc_rev = [None] * num_steps
  if co_dim > 0:
    enc_ci_cell = cell_class(hps.ci_enc_dim,
                             weight_scale=hps.cell_weight_scale,
                             clip_value=hps.cell_clip_value)
    ci_enc_fwd = encode_data(dataset_do_bxtxd, enc_ci_cell,
                             "ci", "forward",
                             hps.num_steps)
    if hps.do_causal_controller:
      ci_enc_rev = None
    else:
      ci_enc_rev = encode_data(dataset_do_bxtxd, enc_ci_cell,
                               "ci", "reverse",
                               hps.num_steps)
    self.ci_enc_fwd = ci_enc_fwd
    self.ci_enc_rev = ci_enc_rev

  # STOCHASTIC LATENT VARIABLES, priors and posteriors
  # (initial conditions g0, and controller outputs u_t).
  with tf.variable_scope("z", reuse=False):
    self.prior_zs_g0 = None
    self.posterior_zs_g0 = None
    self.g0s_val = None
    if ic_dim > 0:
      self.prior_zs_g0 = \
          LearnableDiagonalGaussian(batch_size, ic_dim, name="prior_g0",
                                    mean_init=0.0,
                                    var_min=hps.ic_prior_var_min,
                                    var_init=hps.ic_prior_var_scale,
                                    var_max=hps.ic_prior_var_max)
      ic_enc = tf.concat(axis=1, values=[ic_enc_fwd[-1], ic_enc_rev[0]])
      ic_enc = tf.nn.dropout(ic_enc, keep_prob)
      self.posterior_zs_g0 = \
          DiagonalGaussianFromInput(ic_enc, ic_dim, "ic_enc_2_post_g0",
                                    var_min=hps.ic_post_var_min)

      # Posterior for every kind except prior_sample.
      if kind in ["train", "posterior_sample_and_average",
                  "posterior_push_mean"]:
        zs_g0 = self.posterior_zs_g0
      else:
        zs_g0 = self.prior_zs_g0
      # Sample, except posterior_push_mean which pushes the mean through.
      if kind in ["train", "posterior_sample_and_average", "prior_sample"]:
        self.g0s_val = zs_g0.sample
      else:
        self.g0s_val = zs_g0.mean

    # Priors/posteriors for the controller outputs ('co').
    self.prior_zs_co = [None] * num_steps
    self.posterior_zs_co = posterior_zs_co = [None] * num_steps
    self.zs_co = [None] * num_steps
    self.prior_zs_ar_con = None
    if co_dim > 0:
      autocorrelation_taus = [hps.prior_ar_atau for _ in range(hps.co_dim)]
      noise_variances = [hps.prior_ar_nvar for _ in range(hps.co_dim)]
      self.prior_zs_ar_con = prior_zs_ar_con = \
          LearnableAutoRegressive1Prior(batch_size, hps.co_dim,
                                        autocorrelation_taus,
                                        noise_variances,
                                        hps.do_train_prior_ar_atau,
                                        hps.do_train_prior_ar_nvar,
                                        num_steps, "u_prior_ar1")

  # CONTROLLER -> GENERATOR -> RATES
  # (u(t) -> gen(t) -> factors(t) -> rates(t) -> p(x_t|z_t))
  self.controller_outputs = u_t = [None] * num_steps
  self.con_ics = None
  self.con_states = con_states = [None] * num_steps
  self.con_outs = con_outs = [None] * num_steps
  self.gen_inputs = gen_inputs = [None] * num_steps
  if co_dim > 0:
    # The generator cell class is used so the recurrent weights can be put
    # in an l2 regularization collection.
    con_cell = gen_cell_class(hps.con_dim,
                              input_weight_scale=hps.cell_weight_scale,
                              rec_weight_scale=hps.cell_weight_scale,
                              clip_value=hps.cell_clip_value,
                              recurrent_collections=['l2_con_reg'])
    with tf.variable_scope("con", reuse=False):
      self.con_ics = tf.tile(
          tf.Variable(tf.zeros([1, hps.con_dim*con_cell.state_multiplier]),
                      name="c0"),
          tf.stack([batch_size, 1]))
      self.con_ics.set_shape([None, con_cell.state_size])  # tile loses shape
      # Slot -1 holds the initial state, so index t-1 works at t == 0; it
      # is overwritten by the final step's state.
      con_states[-1] = self.con_ics

  gen_cell = gen_cell_class(hps.gen_dim,
                            input_weight_scale=hps.gen_cell_input_weight_scale,
                            rec_weight_scale=hps.gen_cell_rec_weight_scale,
                            clip_value=hps.cell_clip_value,
                            recurrent_collections=['l2_gen_reg'])
  with tf.variable_scope("gen", reuse=False):
    if ic_dim == 0:
      self.gen_ics = tf.tile(
          tf.Variable(tf.zeros([1, gen_cell.state_size]), name="g0"),
          tf.stack([batch_size, 1]))
    else:
      self.gen_ics = linear(self.g0s_val, gen_cell.state_size,
                            identity_if_possible=True,
                            name="g0_2_gen_ic")

    # As for the controller, slot -1 holds the "t = -1" values.
    self.gen_states = gen_states = [None] * num_steps
    self.gen_outs = gen_outs = [None] * num_steps
    gen_states[-1] = self.gen_ics
    gen_outs[-1] = gen_cell.output_from_state(gen_states[-1])
    self.factors = factors = [None] * num_steps
    factors[-1] = linear(gen_outs[-1], factors_dim, do_bias=False,
                         normalized=True, name="gen_2_fac")

  self.rates = rates = [None] * num_steps
  # rates[-1] is computed so it can be fed back to the controller at t == 0.
  with tf.variable_scope("glm", reuse=False):
    if hps.output_dist == 'poisson':
      log_rates_t0 = tf.matmul(factors[-1], this_out_fac_W) + this_out_fac_b
      log_rates_t0.set_shape([None, None])
      rates[-1] = tf.exp(log_rates_t0)
      rates[-1].set_shape([None, hps.dataset_dims[hps.dataset_names[0]]])
    elif hps.output_dist == 'gaussian':
      mean_n_logvars = tf.matmul(factors[-1], this_out_fac_W) + this_out_fac_b
      mean_n_logvars.set_shape([None, None])
      means_t_bxd, logvars_t_bxd = tf.split(axis=1, num_or_size_splits=2,
                                            value=mean_n_logvars)
      rates[-1] = means_t_bxd
    else:
      assert False, "NIY"

  # Per-step output distribution parameters and accumulated log p(x|z).
  self.output_dist_params = dist_params = [None] * num_steps
  log_p_xgz_b = 0.0
  for t in range(num_steps):
    # Controller
    if co_dim > 0:
      # Forward encoding, lagged; zeros before the start of the trial.
      tlag = t - hps.controller_input_lag
      if tlag < 0:
        con_in_f_t = tf.zeros_like(ci_enc_fwd[0])
      else:
        con_in_f_t = ci_enc_fwd[tlag]
      if hps.do_causal_controller:
        # A causal controller gets no reverse (future-dependent) encoding.
        con_in_list_t = [con_in_f_t]
      else:
        tlag_rev = t + hps.controller_input_lag
        if tlag_rev >= num_steps:
          con_in_r_t = tf.zeros_like(ci_enc_rev[0])
        else:
          con_in_r_t = ci_enc_rev[tlag_rev]
        con_in_list_t = [con_in_f_t, con_in_r_t]

      if hps.do_feed_factors_to_controller:
        if hps.feedback_factors_or_rates == "factors":
          con_in_list_t.append(factors[t-1])
        elif hps.feedback_factors_or_rates == "rates":
          con_in_list_t.append(rates[t-1])
        else:
          assert False, "NIY"

      con_in_t = tf.concat(axis=1, values=con_in_list_t)
      con_in_t = tf.nn.dropout(con_in_t, keep_prob)
      with tf.variable_scope("con", reuse=True if t > 0 else None):
        con_outs[t], con_states[t] = con_cell(con_in_t, con_states[t-1])
        posterior_zs_co[t] = \
            DiagonalGaussianFromInput(con_outs[t], co_dim,
                                      name="con_to_post_co")
      if kind == "train":
        u_t[t] = posterior_zs_co[t].sample
      elif kind == "posterior_sample_and_average":
        u_t[t] = posterior_zs_co[t].sample
      elif kind == "posterior_push_mean":
        u_t[t] = posterior_zs_co[t].mean
      else:
        u_t[t] = prior_zs_ar_con.samples_t[t]

    # Optionally inject the external inputs directly into the generator.
    if ext_input_dim > 0 and hps.inject_ext_input_to_gen:
      ext_input_t_bxi = ext_input_do_bxtxi[:, t, :]
      if co_dim > 0:
        gen_inputs[t] = tf.concat(axis=1, values=[u_t[t], ext_input_t_bxi])
      else:
        gen_inputs[t] = ext_input_t_bxi
    else:
      gen_inputs[t] = u_t[t]  # None when there is no controller

    # Generator
    data_t_bxd = dataset_ph[:, t, :]
    with tf.variable_scope("gen", reuse=True if t > 0 else None):
      gen_outs[t], gen_states[t] = gen_cell(gen_inputs[t], gen_states[t-1])
      gen_outs[t] = tf.nn.dropout(gen_outs[t], keep_prob)
    with tf.variable_scope("gen", reuse=True):  # created for slot -1 above
      factors[t] = linear(gen_outs[t], factors_dim, do_bias=False,
                          normalized=True, name="gen_2_fac")
    with tf.variable_scope("glm", reuse=True if t > 0 else None):
      if hps.output_dist == 'poisson':
        log_rates_t = tf.matmul(factors[t], this_out_fac_W) + this_out_fac_b
        log_rates_t.set_shape([None, None])
        # Clip before exp for the reported rates; the likelihood below uses
        # the unclipped log rates.
        rates[t] = dist_params[t] = tf.exp(tf.clip_by_value(
            log_rates_t, -hps._clip_value, hps._clip_value))
        rates[t].set_shape([None, hps.dataset_dims[hps.dataset_names[0]]])
        loglikelihood_t = Poisson(log_rates_t).logp(data_t_bxd)
      elif hps.output_dist == 'gaussian':
        mean_n_logvars = tf.matmul(factors[t], this_out_fac_W) + this_out_fac_b
        mean_n_logvars.set_shape([None, None])
        means_t_bxd, logvars_t_bxd = tf.split(axis=1, num_or_size_splits=2,
                                              value=mean_n_logvars)
        rates[t] = means_t_bxd
        dist_params[t] = tf.concat(
            axis=1,
            values=[means_t_bxd,
                    tf.exp(tf.clip_by_value(logvars_t_bxd,
                                            -hps._clip_value,
                                            hps._clip_value))])
        loglikelihood_t = \
            diag_gaussian_log_likelihood(data_t_bxd,
                                         means_t_bxd, logvars_t_bxd)
      else:
        assert False, "NIY"

      log_p_xgz_b += tf.reduce_sum(loglikelihood_t, [1])

  # '+=' rebinds the local, so store the accumulated value explicitly.
  self.log_p_xgz_b = log_p_xgz_b

  # Correlation cost on the posterior means of the inferred inputs.
  self.corr_cost = tf.constant(0.0)
  if hps.co_mean_corr_scale > 0.0:
    all_sum_corr = []
    for i in range(hps.co_dim):
      for j in range(i+1, hps.co_dim):
        sum_corr_ij = tf.constant(0.0)
        for t in range(num_steps):
          u_mean_t = posterior_zs_co[t].mean
          sum_corr_ij += u_mean_t[:, i]*u_mean_t[:, j]
        all_sum_corr.append(0.5 * tf.square(sum_corr_ij))
    self.corr_cost = tf.reduce_mean(all_sum_corr)

  # Variational lower bound: KL terms plus reconstruction cost.
  kl_cost_g0_b = tf.zeros_like(batch_size, dtype=tf.float32)
  kl_cost_co_b = tf.zeros_like(batch_size, dtype=tf.float32)
  self.kl_cost = tf.constant(0.0)
  self.recon_cost = tf.constant(0.0)
  self.nll_bound_vae = tf.constant(0.0)
  self.nll_bound_iwae = tf.constant(0.0)
  if kind in ["train", "posterior_sample_and_average", "posterior_push_mean"]:
    kl_cost_g0_b = 0.0
    kl_cost_co_b = 0.0
    if ic_dim > 0:
      g0_priors = [self.prior_zs_g0]
      g0_posts = [self.posterior_zs_g0]
      kl_cost_g0_b = KLCost_GaussianGaussian(g0_posts, g0_priors).kl_cost_b
      kl_cost_g0_b = hps.kl_ic_weight * kl_cost_g0_b
    if co_dim > 0:
      kl_cost_co_b = \
          KLCost_GaussianGaussianProcessSampled(
              posterior_zs_co, prior_zs_ar_con).kl_cost_b
      kl_cost_co_b = hps.kl_co_weight * kl_cost_co_b

    # The 'reconstruction cost' is the negative log likelihood.
    self.recon_cost = - tf.reduce_mean(log_p_xgz_b)
    self.kl_cost = tf.reduce_mean(kl_cost_g0_b + kl_cost_co_b)

    lb_on_ll_b = log_p_xgz_b - kl_cost_g0_b - kl_cost_co_b

    # VAE bound averages outside the log ...
    self.nll_bound_vae = -tf.reduce_mean(lb_on_ll_b)

    # ... the IWAE bound averages inside the log.
    k = tf.cast(tf.shape(log_p_xgz_b)[0], tf.float32)
    iwae_lb_on_ll = -tf.log(k) + log_sum_exp(lb_on_ll_b)
    self.nll_bound_iwae = -iwae_lb_on_ll

  # L2 regularization on the recurrent weights of generator and controller,
  # normalized by the total number of regularized parameters.
  self.l2_cost = tf.constant(0.0)
  if self.hps.l2_gen_scale > 0.0 or self.hps.l2_con_scale > 0.0:
    l2_costs = []
    l2_numels = []
    l2_reg_var_lists = [tf.get_collection('l2_gen_reg'),
                        tf.get_collection('l2_con_reg')]
    l2_reg_scales = [self.hps.l2_gen_scale, self.hps.l2_con_scale]
    for l2_reg_vars, l2_scale in zip(l2_reg_var_lists, l2_reg_scales):
      for v in l2_reg_vars:
        numel = tf.reduce_prod(tf.concat(axis=0, values=tf.shape(v)))
        numel_f = tf.cast(numel, tf.float32)
        l2_numels.append(numel_f)
        v_l2 = tf.reduce_sum(v*v)
        l2_costs.append(0.5 * l2_scale * v_l2)
    self.l2_cost = tf.add_n(l2_costs) / tf.add_n(l2_numels)

  # KL and L2 weights ramp linearly from 0 to 1 over *_increase_steps,
  # starting at *_start_step.
  self.kl_decay_step = tf.maximum(self.train_step - hps.kl_start_step, 0)
  self.l2_decay_step = tf.maximum(self.train_step - hps.l2_start_step, 0)
  kl_decay_step_f = tf.cast(self.kl_decay_step, tf.float32)
  l2_decay_step_f = tf.cast(self.l2_decay_step, tf.float32)
  kl_increase_steps_f = tf.cast(hps.kl_increase_steps, tf.float32)
  l2_increase_steps_f = tf.cast(hps.l2_increase_steps, tf.float32)
  self.kl_weight = kl_weight = \
      tf.minimum(kl_decay_step_f / kl_increase_steps_f, 1.0)
  self.l2_weight = l2_weight = \
      tf.minimum(l2_decay_step_f / l2_increase_steps_f, 1.0)

  self.timed_kl_cost = kl_weight * self.kl_cost
  self.timed_l2_cost = l2_weight * self.l2_cost
  self.weight_corr_cost = hps.co_mean_corr_scale * self.corr_cost
  self.cost = self.recon_cost + self.timed_kl_cost + \
      self.timed_l2_cost + self.weight_corr_cost

  def make_savers():
    # Must run after all variables exist (including optimizer slots when
    # training), since each Saver snapshots tf.global_variables() now.
    # 'seso': save every so often; 'lve': lowest validation error.
    self.seso_saver = tf.train.Saver(tf.global_variables(),
                                     max_to_keep=hps.max_ckpt_to_keep)
    self.lve_saver = tf.train.Saver(tf.global_variables(),
                                    max_to_keep=hps.max_ckpt_to_keep_lve)

  if kind != "train":
    make_savers()
    return

  # OPTIMIZATION: choose which variables are trained.
  if self.hps.do_train_io_only:
    self.train_vars = tvars = \
        tf.get_collection('IO_transformations',
                          scope=tf.get_variable_scope().name)
  elif self.hps.do_train_encoder_only:
    tvars1 = \
        tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                          scope='LFADS/ic_enc_*')
    tvars2 = \
        tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                          scope='LFADS/z/ic_enc_*')
    self.train_vars = tvars = tvars1 + tvars2
  else:
    self.train_vars = tvars = \
        tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                          scope=tf.get_variable_scope().name)
  print("done.")
  print("Model Variables (to be optimized): ")
  total_params = 0
  for i, tvar in enumerate(tvars):
    shape = tvar.get_shape().as_list()
    print(" ", i, tvar.name, shape)
    total_params += np.prod(shape)
  print("Total model parameters: ", total_params)

  grads = tf.gradients(self.cost, tvars)
  grads, grad_global_norm = tf.clip_by_global_norm(grads, hps.max_grad_norm)
  opt = tf.train.AdamOptimizer(self.learning_rate, beta1=0.9, beta2=0.999,
                               epsilon=1e-01)
  self.grads = grads
  self.grad_global_norm = grad_global_norm
  self.train_op = opt.apply_gradients(
      zip(grads, tvars), global_step=self.train_step)

  make_savers()

  # SUMMARIES
  self.example_image = tf.placeholder(tf.float32, shape=[1, None, None, 3],
                                      name='image_tensor')
  self.example_summ = tf.summary.image("LFADS example", self.example_image,
                                       collections=["example_summaries"])

  # General training summaries (default 'summaries' collection).
  self.lr_summ = tf.summary.scalar("Learning rate", self.learning_rate)
  self.kl_weight_summ = tf.summary.scalar("KL weight", self.kl_weight)
  self.l2_weight_summ = tf.summary.scalar("L2 weight", self.l2_weight)
  self.corr_cost_summ = tf.summary.scalar("Corr cost", self.weight_corr_cost)
  self.grad_global_norm_summ = tf.summary.scalar("Gradient global norm",
                                                 self.grad_global_norm)
  if hps.co_dim > 0:
    self.atau_summ = [None] * hps.co_dim
    self.pvar_summ = [None] * hps.co_dim
    for c in range(hps.co_dim):
      self.atau_summ[c] = \
          tf.summary.scalar("AR Autocorrelation taus " + str(c),
                            tf.exp(self.prior_zs_ar_con.logataus_1xu[0, c]))
      self.pvar_summ[c] = \
          tf.summary.scalar("AR Variances " + str(c),
                            tf.exp(self.prior_zs_ar_con.logpvars_1xu[0, c]))

  # Cost summaries are fed through placeholders and kept in separate
  # train/valid collections, so the caller supplies the values to report.
  kl_cost_ph = tf.placeholder(tf.float32, shape=[], name='kl_cost_ph')
  self.kl_t_cost_summ = tf.summary.scalar("KL cost (train)", kl_cost_ph,
                                          collections=["train_summaries"])
  self.kl_v_cost_summ = tf.summary.scalar("KL cost (valid)", kl_cost_ph,
                                          collections=["valid_summaries"])
  l2_cost_ph = tf.placeholder(tf.float32, shape=[], name='l2_cost_ph')
  self.l2_cost_summ = tf.summary.scalar("L2 cost", l2_cost_ph,
                                        collections=["train_summaries"])

  recon_cost_ph = tf.placeholder(tf.float32, shape=[], name='recon_cost_ph')
  self.recon_t_cost_summ = tf.summary.scalar("Reconstruction cost (train)",
                                             recon_cost_ph,
                                             collections=["train_summaries"])
  self.recon_v_cost_summ = tf.summary.scalar("Reconstruction cost (valid)",
                                             recon_cost_ph,
                                             collections=["valid_summaries"])

  total_cost_ph = tf.placeholder(tf.float32, shape=[], name='total_cost_ph')
  self.cost_t_summ = tf.summary.scalar("Total cost (train)", total_cost_ph,
                                       collections=["train_summaries"])
  self.cost_v_summ = tf.summary.scalar("Total cost (valid)", total_cost_ph,
                                       collections=["valid_summaries"])

  self.kl_cost_ph = kl_cost_ph
  self.l2_cost_ph = l2_cost_ph
  self.recon_cost_ph = recon_cost_ph
  self.total_cost_ph = total_cost_ph

  # Merged summaries, for easy coding later.
  self.merged_examples = tf.summary.merge_all(key="example_summaries")
  self.merged_generic = tf.summary.merge_all()  # default key
  self.merged_train = tf.summary.merge_all(key="train_summaries")
  self.merged_valid = tf.summary.merge_all(key="valid_summaries")

  self.logfile = os.path.join(hps.lfads_save_dir, "lfads_log")
  self.writer = tf.summary.FileWriter(self.logfile)
def build_feed_dict(self, train_name, data_bxtxd, ext_input_bxtxi=None,
                    keep_prob=None):
    """Assemble the feed dictionary for a single session.run() call.

    Args:
      train_name: The key into the datasets; it is fed to self.dataName so
        the graph can pick the matching readin / readout matrices.
      data_bxtxd: The data tensor (must unpack into three dimensions).
      ext_input_bxtxi (optional): The external input tensor.  It is only fed
        when the model has an external input placeholder and a value is given.
      keep_prob (optional): The dropout keep probability.  When None, the
        value stored in self.hps.keep_prob is used.

    Returns:
      The feed dictionary with TF tensors as keys and data as values, for use
      with tf.Session.run()
    """
    # The unpack doubles as a rank check: anything but 3 dimensions raises.
    _, _, _ = data_bxtxd.shape

    feed_dict = {self.dataName: train_name, self.dataset_ph: data_bxtxd}

    has_ext_input = (self.ext_input is not None and
                     ext_input_bxtxi is not None)
    if has_ext_input:
        feed_dict[self.ext_input] = ext_input_bxtxi

    feed_dict[self.keep_prob] = (self.hps.keep_prob if keep_prob is None
                                 else keep_prob)
    return feed_dict
@staticmethod
def get_batch(data_extxd, ext_input_extxi=None, batch_size=None,
              example_idxs=None):
    """Select a batch of examples, either at random or by explicit index.

    Args:
      data_extxd: The data to model, numpy tensors with shape:
        # examples x # time steps x # dimensions
      ext_input_extxi (optional): The external inputs, numpy tensor with
        shape: # examples x # time steps x # external input dimensions
      batch_size: Number of examples to draw (with replacement) when
        example_idxs is not given.
      example_idxs (optional): The example indices used to select examples.
        When given, batch_size is ignored.

    Returns:
      A tuple with two parts:
        1. Batched data numpy tensor with shape:
           batch_size x # time steps x # dimensions
        2. Batched external input numpy tensor with shape:
           batch_size x # time steps x # external input dims
           (None when no external input was supplied).
    """
    assert batch_size is not None or example_idxs is not None, "Problems"

    nexamples, _, _ = data_extxd.shape
    if example_idxs is None:
        example_idxs = np.random.choice(nexamples, batch_size)

    # The same indices are applied to the data and the external inputs so
    # the two stay aligned example-for-example.
    ext_input_bxtxi = (None if ext_input_extxi is None
                       else ext_input_extxi[example_idxs, :, :])
    return data_extxd[example_idxs, :, :], ext_input_bxtxi
@staticmethod
def example_idxs_mod_batch_size(nexamples, batch_size):
    """Given a number of examples, E, and a batch_size, B, generate indices
    [0, 1, 2, ... B-1;
     [B, B+1, ... 2*B-1;
     ...
    ]
    returning those indices as a 2-dim tensor shaped like E/B x B.  Note that
    shape is only exact if E % B == 0.  If not, an extra row is generated so
    that the remainder of the examples is included; the missing slots of that
    row are filled with randomly drawn example indices, sorted ascending
    (see randomize_example_idxs_mod_batch_size for fully randomized
    behavior).

    Args:
      nexamples: The number of examples to batch up.
      batch_size: The size of the batch.

    Returns:
      A tuple (indices, bmrem): the 2-dim index tensor described above and
      batch_size - (nexamples % batch_size).  Note that bmrem equals
      batch_size (not 0) when no padding was needed.
    """
    bmrem = batch_size - (nexamples % batch_size)
    bmrem_examples = []
    if bmrem < batch_size:
        if bmrem <= nexamples:
            # Distinct padding examples, as long as there are enough of them.
            ridxs = np.random.permutation(nexamples)[0:bmrem]
        else:
            # Fewer examples than padding slots: repeats are unavoidable.
            ridxs = np.random.choice(nexamples, size=bmrem, replace=True)
        bmrem_examples = np.sort(ridxs.astype(np.int32))
    # list(range(...)) is required on Python 3, where range is not a list
    # and cannot be concatenated with one.
    example_idxs = list(range(nexamples)) + list(bmrem_examples)
    example_idxs_e_x_edivb = np.reshape(example_idxs, [-1, batch_size])
    return example_idxs_e_x_edivb, bmrem
@staticmethod
def randomize_example_idxs_mod_batch_size(nexamples, batch_size):
    """Indices 0 .. nexamples-1, randomized, in 2D form of
    shape = ceil(nexamples / batch_size) x batch_size.  When nexamples is not
    a multiple of batch_size, the last batch is completed with extra indices
    drawn at random, without replacement, from 0 .. nexamples-1.

    Args:
      nexamples: number of examples to randomize
      batch_size: number of elements in batch

    Returns:
      A tuple (indices, bmrem): the randomized, properly shaped indices and
      batch_size - (nexamples % batch_size).  Note that bmrem equals
      batch_size (not 0) when no padding was needed.
    """
    assert nexamples > batch_size, "Problems"
    bmrem = batch_size - nexamples % batch_size
    bmrem_examples = []
    if bmrem < batch_size:
        bmrem_examples = np.random.choice(range(nexamples),
                                          size=bmrem, replace=False)
    # list(range(...)) is required on Python 3, where range is not a list
    # and cannot be concatenated with one.
    example_idxs = list(range(nexamples)) + list(bmrem_examples)
    mixed_example_idxs = np.random.permutation(example_idxs)
    example_idxs_e_x_edivb = np.reshape(mixed_example_idxs, [-1, batch_size])
    return example_idxs_e_x_edivb, bmrem
def shuffle_spikes_in_time(self, data_bxtxd):
    """Shuffle the spikes in the temporal dimension.  This is useful to
    help the LFADS system avoid overfitting to individual spikes or fast
    oscillations found in the data that are irrelevant to behavior.  A
    pure 'tabula rasa' approach would avoid this, but LFADS is sensitive
    enough to pick up dynamics that you may not want.

    Each spike is moved independently by a random number of time bins drawn
    from [-w, w), where w is hps.temporal_spike_jitter_width.  Spikes pushed
    past either end of the trial are reflected back inside, so the spike
    count of every (example, channel) pair is preserved.

    Args:
      data_bxtxd: numpy array of spike count data to be shuffled.  Integer
        and float arrays are both accepted.

    Returns:
      S_bxtxd, a numpy array with the same dimensions and total counts as
      data_bxtxd, but jittered in time.  When the jitter width is 0 the
      input array itself is returned.
    """
    B, T, N = data_bxtxd.shape
    w = self.hps.temporal_spike_jitter_width

    if w == 0:
        return data_bxtxd

    S_bxtxd = np.zeros([B, T, N])
    if data_bxtxd.size == 0:
        # np.max raises on empty input; there is nothing to shuffle anyway.
        return S_bxtxd

    # Counts may arrive as floats; range() needs a true integer.
    max_counts = int(np.max(data_bxtxd))

    # Intuitively, shuffle spike occurrences, 0 or 1, but since we have
    # counts, do it over and over again up to the max count.
    for mc in range(1, max_counts + 1):
        bidxs, tidxs, nidxs = np.nonzero(data_bxtxd >= mc)
        # NOTE(review): randint excludes the upper bound, so the jitter is
        # drawn from [-w, w-1]; confirm that the asymmetry is intended.
        shuffle_tidxs = tidxs + np.random.randint(-w, w, size=len(tidxs))

        # Reflect on the boundaries to not lose mass.
        before_start = shuffle_tidxs < 0
        shuffle_tidxs[before_start] = -shuffle_tidxs[before_start]
        past_end = shuffle_tidxs > T - 1
        shuffle_tidxs[past_end] = (T - 1) - (shuffle_tidxs[past_end] - (T - 1))

        # Unbuffered add: several spikes landing in one bin all count.
        np.add.at(S_bxtxd, (bidxs, shuffle_tidxs, nidxs), 1)

    return S_bxtxd
def shuffle_and_flatten_datasets(self, datasets, kind='train'):
    """Since LFADS supports multiple datasets in the same dynamical model,
    we have to be careful to use all the data in a single training epoch.  But
    since the datasets may have different data dimensionality, we cannot batch
    examples from data dictionaries together.  Instead, we generate random
    batches within each data dictionary, and then randomize these batches
    while holding onto the dataname, so that when it's time to feed
    the graph, the correct in/out matrices can be selected, per batch.

    Datasets whose data of the requested kind is None (e.g. no validation
    set) contribute no batches.

    Args:
      datasets: A dict of data dicts.  The dataset dict is simply a
        name(string)-> data dictionary mapping (See top of lfads.py).
      kind: 'train' or 'valid'

    Returns:
      A flat list, in which each element is a pair ('name', indices).
    """
    batch_size = self.hps.batch_size
    kind_data = kind + '_data'

    all_name_example_idx_pairs = []
    for name, data_dict in datasets.items():
        data_extxd = data_dict[kind_data]
        if data_extxd is None:
            # Other code paths allow a dataset without e.g. validation data.
            continue
        nexamples, _, _ = data_extxd.shape
        random_example_idxs, _ = \
            self.randomize_example_idxs_mod_batch_size(nexamples, batch_size)
        # One (name, batch indices) pair per row of the index matrix.
        all_name_example_idx_pairs.extend(
            (name, batch_idxs) for batch_idxs in random_example_idxs)

    np.random.shuffle(all_name_example_idx_pairs)  # shuffle in place

    return all_name_example_idx_pairs
def train_epoch(self, datasets, batch_size=None, do_save_ckpt=True):
    """Run one full training pass over all datasets.

    Args:
      datasets: A dict of data dicts.  The dataset dict is simply a
        name(string)-> data dictionary mapping (See top of lfads.py).
      batch_size (optional): Accepted for interface compatibility; it is not
        read by this method.
      do_save_ckpt (optional): Should the routine save a checkpoint on this
        training epoch?

    Returns:
      A tuple with 6 float values:
      (total cost of the epoch, epoch reconstruction cost,
       epoch kl cost, KL weight used this training epoch,
       total l2 cost on generator, and the corresponding weight).
      The three costs are means over the batches of the epoch; the weights
      and the l2 cost are those reported by the last batch.
    """
    fetches = [self.cost, self.recon_cost, self.kl_cost,
               self.kl_weight, self.l2_cost, self.l2_weight,
               self.train_op]
    per_batch_values = self.run_epoch(datasets, fetches, kind="train")

    nbatches = len(per_batch_values)
    cost_sum = recon_sum = kl_sum = 0.0
    for batch_values in per_batch_values:
        cost_sum += batch_values[0]
        recon_sum += batch_values[1]
        kl_sum += batch_values[2]

    # Weights and the l2 cost are taken from the final batch only.
    final_values = per_batch_values[-1]
    kl_weight = final_values[3]
    l2_cost = final_values[4]
    l2_weight = final_values[5]

    epoch_total_cost = cost_sum / nbatches
    epoch_recon_cost = recon_sum / nbatches
    epoch_kl_cost = kl_sum / nbatches

    if do_save_ckpt:
        checkpoint_path = os.path.join(self.hps.lfads_save_dir,
                                       self.hps.checkpoint_name + '.ckpt')
        self.seso_saver.save(tf.get_default_session(), checkpoint_path,
                             global_step=self.train_step)

    return (epoch_total_cost, epoch_recon_cost, epoch_kl_cost,
            kl_weight, l2_cost, l2_weight)
def run_epoch(self, datasets, ops_to_eval, kind="train", batch_size=None,
              do_collect=True, keep_prob=None):
    """Run the model through the entire dataset once.

    Args:
      datasets: A dict of data dicts.  The dataset dict is simply a
        name(string)-> data dictionary mapping (See top of lfads.py).
      ops_to_eval: A list of tensorflow operations that will be evaluated in
        the tf.session.run() call.
      kind (optional): 'train' or 'valid'; selects the '<kind>_data' and
        '<kind>_ext_input' entries of each data dictionary.
      batch_size (optional): Accepted for interface compatibility; the batch
        size actually used comes from the hyperparameters.
      do_collect (optional): Should the routine collect all session.run
        output as a list, and return it?
      keep_prob (optional): The dropout keep probability.

    Returns:
      A list of lists, the internal list is the return for the ops for each
      session.run() call.  The outer list collects over the epoch.  The list
      is empty when do_collect is False.
    """
    hps = self.hps
    all_name_example_idx_pairs = \
        self.shuffle_and_flatten_datasets(datasets, kind)

    kind_data = kind + '_data'
    kind_ext_input = kind + '_ext_input'

    # NOTE(review): the jitter is applied for every kind, validation
    # included; confirm that this is intended.
    do_jitter = (hps.output_dist == 'poisson' and
                 hps.temporal_spike_jitter_width > 0)

    session = tf.get_default_session()
    evaled_ops_list = []
    for name, example_idxs in all_name_example_idx_pairs:
        data_dict = datasets[name]
        data_bxtxd, ext_input_bxtxi = self.get_batch(
            data_dict[kind_data], data_dict[kind_ext_input],
            example_idxs=example_idxs)

        # Jitter only the selected batch.  Spikes are jittered independently,
        # so this matches jittering the whole dataset and then selecting,
        # without reprocessing every example for every batch.
        if do_jitter:
            data_bxtxd = self.shuffle_spikes_in_time(data_bxtxd)

        feed_dict = self.build_feed_dict(name, data_bxtxd, ext_input_bxtxi,
                                         keep_prob=keep_prob)
        evaled_ops_np = session.run(ops_to_eval, feed_dict=feed_dict)
        if do_collect:
            evaled_ops_list.append(evaled_ops_np)

    return evaled_ops_list
def summarize_all(self, datasets, summary_values):
    """Plot and summarize stuff in tensorboard.

    Writes the train (and, if available, validation) cost summaries, prints
    a one-line progress report and, when hps.csv_log is set, appends the
    same numbers to '<lfads_save_dir>/<csv_log>.csv'.

    Args:
      datasets: the dictionary of datasets used in the study (not read by
        this method).
      summary_values: These summary values are created from the training
        loop, and so summarize the entire set of datasets.  Keys read here:
        'tr_kl_cost', 'tr_recon_cost', 'tr_total_cost', 'kl_weight',
        'l2_weight', 'l2_cost', 'has_any_valid_set', 'nepochs' and, when
        has_any_valid_set is true, 'ev_kl_cost', 'ev_recon_cost',
        'ev_total_cost'.
    """
    hps = self.hps
    tr_kl_cost = summary_values['tr_kl_cost']
    tr_recon_cost = summary_values['tr_recon_cost']
    tr_total_cost = summary_values['tr_total_cost']
    kl_weight = summary_values['kl_weight']
    l2_weight = summary_values['l2_weight']
    l2_cost = summary_values['l2_cost']
    has_any_valid_set = summary_values['has_any_valid_set']
    i = summary_values['nepochs']

    session = tf.get_default_session()
    train_summ, train_step = session.run(
        [self.merged_train, self.train_step],
        feed_dict={self.l2_cost_ph: l2_cost,
                   self.kl_cost_ph: tr_kl_cost,
                   self.recon_cost_ph: tr_recon_cost,
                   self.total_cost_ph: tr_total_cost})
    self.writer.add_summary(train_summ, train_step)

    # The format strings are built from adjacent literals rather than with
    # backslash continuations inside one literal, so source indentation
    # cannot leak into the console or CSV output.
    if has_any_valid_set:
        ev_kl_cost = summary_values['ev_kl_cost']
        ev_recon_cost = summary_values['ev_recon_cost']
        ev_total_cost = summary_values['ev_total_cost']
        eval_summ = session.run(
            self.merged_valid,
            feed_dict={self.kl_cost_ph: ev_kl_cost,
                       self.recon_cost_ph: ev_recon_cost,
                       self.total_cost_ph: ev_total_cost})
        self.writer.add_summary(eval_summ, train_step)

        print_fmt = ("Epoch:%d, step:%d (TRAIN, VALID): total: %.2f, %.2f "
                     "recon: %.2f, %.2f, kl: %.2f, %.2f, l2: %.5f, "
                     "kl weight: %.2f, l2 weight: %.2f")
        csv_fmt = ("epoch,%d, step,%d, total,%.2f,%.2f, "
                   "recon,%.2f,%.2f, kl,%.2f,%.2f, l2,%.5f, "
                   "klweight,%.2f, l2weight,%.2f\n")
        fmt_args = (i, train_step, tr_total_cost, ev_total_cost,
                    tr_recon_cost, ev_recon_cost, tr_kl_cost, ev_kl_cost,
                    l2_cost, kl_weight, l2_weight)
    else:
        print_fmt = ("Epoch:%d, step:%d TRAIN: total: %.2f recon: %.2f, "
                     "kl: %.2f, l2: %.5f, kl weight: %.2f, l2 weight: %.2f")
        csv_fmt = ("epoch,%d, step,%d, total,%.2f, recon,%.2f, kl,%.2f, "
                   "l2,%.5f, klweight,%.2f, l2weight,%.2f\n")
        fmt_args = (i, train_step, tr_total_cost, tr_recon_cost, tr_kl_cost,
                    l2_cost, kl_weight, l2_weight)

    print(print_fmt % fmt_args)

    if hps.csv_log:
        csv_file = os.path.join(hps.lfads_save_dir, hps.csv_log + '.csv')
        with open(csv_file, "a") as myfile:
            myfile.write(csv_fmt % fmt_args)
def plot_single_example(self, datasets):
    """Plot an image relating to a randomly chosen, specific example.  We use
    posterior sample and average by taking one example, and filling a whole
    batch with that example, sample from the posterior, and then average the
    quantities.

    One dataset is picked at random; one training example (and, when the
    dataset has validation data, one validation example) is then picked at
    random from it.  The generic summaries and the example image are written
    to the summary writer at the current training step.

    Args:
      datasets: A dict of data dicts.  The dataset dict is simply a
        name(string)-> data dictionary mapping (See top of lfads.py).
    """
    hps = self.hps
    # Materialize the keys: on Python 3 dict.keys() is a view that numpy
    # cannot permute or index.
    all_data_names = list(datasets.keys())
    data_name = all_data_names[np.random.choice(len(all_data_names))]
    data_dict = datasets[data_name]
    has_valid_set = data_dict['valid_data'] is not None
    cf = 1.0  # plotting concern

    # Fill a whole batch with copies of one random training example.
    E, _, _ = data_dict['train_data'].shape
    eidx = np.random.choice(E)
    example_idxs = eidx * np.ones(hps.batch_size, dtype=np.int32)

    train_data_bxtxd, train_ext_input_bxtxi = \
        self.get_batch(data_dict['train_data'], data_dict['train_ext_input'],
                       example_idxs=example_idxs)

    truth_train_data_bxtxd = None
    if 'train_truth' in data_dict and data_dict['train_truth'] is not None:
        truth_train_data_bxtxd, _ = self.get_batch(data_dict['train_truth'],
                                                   example_idxs=example_idxs)
        cf = data_dict['conversion_factor']

    train_model_values = self.eval_model_runs_batch(data_name,
                                                    train_data_bxtxd,
                                                    train_ext_input_bxtxi,
                                                    do_average_batch=False)

    train_step = train_model_values['train_steps']
    feed_dict = self.build_feed_dict(data_name, train_data_bxtxd,
                                     train_ext_input_bxtxi, keep_prob=1.0)

    session = tf.get_default_session()
    generic_summ = session.run(self.merged_generic, feed_dict=feed_dict)
    self.writer.add_summary(generic_summ, train_step)

    valid_data_bxtxd = valid_model_values = valid_ext_input_bxtxi = None
    truth_valid_data_bxtxd = None
    if has_valid_set:
        E, _, _ = data_dict['valid_data'].shape
        eidx = np.random.choice(E)
        example_idxs = eidx * np.ones(hps.batch_size, dtype=np.int32)
        valid_data_bxtxd, valid_ext_input_bxtxi = \
            self.get_batch(data_dict['valid_data'],
                           data_dict['valid_ext_input'],
                           example_idxs=example_idxs)
        if 'valid_truth' in data_dict and data_dict['valid_truth'] is not None:
            truth_valid_data_bxtxd, _ = self.get_batch(
                data_dict['valid_truth'], example_idxs=example_idxs)

        valid_model_values = self.eval_model_runs_batch(data_name,
                                                        valid_data_bxtxd,
                                                        valid_ext_input_bxtxi,
                                                        do_average_batch=False)

    example_image = plot_lfads(train_bxtxd=train_data_bxtxd,
                               train_model_vals=train_model_values,
                               train_ext_input_bxtxi=train_ext_input_bxtxi,
                               train_truth_bxtxd=truth_train_data_bxtxd,
                               valid_bxtxd=valid_data_bxtxd,
                               valid_model_vals=valid_model_values,
                               valid_ext_input_bxtxi=valid_ext_input_bxtxi,
                               valid_truth_bxtxd=truth_valid_data_bxtxd,
                               bidx=None, cf=cf, output_dist=hps.output_dist)
    # Add a leading axis of length 1 before feeding the image placeholder.
    example_image = np.expand_dims(example_image, axis=0)
    example_summ = session.run(self.merged_examples,
                               feed_dict={self.example_image: example_image})
    # Tag the image with the training step, like every other summary here,
    # so successive images are not all recorded without a step.
    self.writer.add_summary(example_summ, train_step)
def train_model(self, datasets):
    """Train the model, print per-epoch information, and save checkpoints.

    Loop over training epochs.  The function that actually does the
    training is train_epoch.  This function iterates over the training
    data, one epoch at a time.  The learning rate schedule is such
    that it will stay the same until the cost goes up in comparison to
    the last few values, then it will drop.  Training stops once the
    learning rate falls below hps.learning_rate_stop.

    Args:
      datasets: A dict of data dicts.  The dataset dict is simply a
        name(string)-> data dictionary mapping (See top of lfads.py).
    """
    hps = self.hps
    has_any_valid_set = any(data_dict['valid_data'] is not None
                            for data_dict in datasets.values())

    session = tf.get_default_session()
    lr = session.run(self.learning_rate)
    lr_stop = hps.learning_rate_stop
    i = -1
    train_costs = []
    valid_costs = []
    ev_total_cost = ev_recon_cost = ev_kl_cost = 0.0
    # np.inf, not np.Inf: the capitalized alias was removed in NumPy 2.0.
    lowest_ev_cost = np.inf
    while True:
        i += 1
        do_save_ckpt = (i % 10 == 0)
        (tr_total_cost, tr_recon_cost, tr_kl_cost,
         kl_weight, l2_cost, l2_weight) = \
            self.train_epoch(datasets, do_save_ckpt=do_save_ckpt)

        # Evaluate the validation cost, and potentially save.  Note that this
        # routine will not save a validation checkpoint until the kl weight
        # and l2 weights are equal to 1.0.
        if has_any_valid_set:
            ev_total_cost, ev_recon_cost, ev_kl_cost = \
                self.eval_cost_epoch(datasets, kind='valid')
            valid_costs.append(ev_total_cost)

            # With n_lve == 1 the running average is just the latest cost.
            n_lve = 1
            run_avg_lve = np.mean(valid_costs[-n_lve:])

            # Conditions for saving the lowest-validation-error checkpoint:
            #   the KL weight has finished stepping (>= 1.0), AND
            #   the L2 weight has finished stepping OR L2 is not used, AND
            #   more than n_lve validation costs exist and the current one
            #   is lower than any seen so far.
            kl_done = kl_weight >= 1.0
            l2_done = (l2_weight >= 1.0 or
                       (hps.l2_gen_scale == 0.0 and hps.l2_con_scale == 0.0))
            is_new_best = (len(valid_costs) > n_lve and
                           run_avg_lve < lowest_ev_cost)
            if kl_done and l2_done and is_new_best:
                lowest_ev_cost = run_avg_lve
                checkpoint_path = os.path.join(
                    hps.lfads_save_dir, hps.checkpoint_name + '_lve.ckpt')
                self.lve_saver.save(session, checkpoint_path,
                                    global_step=self.train_step,
                                    latest_filename='checkpoint_lve')

        # Plot and summarize.
        values = {'nepochs': i, 'has_any_valid_set': has_any_valid_set,
                  'tr_total_cost': tr_total_cost,
                  'ev_total_cost': ev_total_cost,
                  'tr_recon_cost': tr_recon_cost,
                  'ev_recon_cost': ev_recon_cost,
                  'tr_kl_cost': tr_kl_cost, 'ev_kl_cost': ev_kl_cost,
                  'l2_weight': l2_weight, 'kl_weight': kl_weight,
                  'l2_cost': l2_cost}
        self.summarize_all(datasets, values)
        self.plot_single_example(datasets)

        # Manage learning rate: decay when the training cost is worse than
        # all of the last n_lr recorded costs.
        train_res = tr_total_cost
        n_lr = hps.learning_rate_n_to_compare
        if len(train_costs) > n_lr and train_res > np.max(train_costs[-n_lr:]):
            _ = session.run(self.learning_rate_decay_op)
            lr = session.run(self.learning_rate)
            print(" Decreasing learning rate to %f." % lr)
            # Recording inf keeps the comparison above false until the inf
            # leaves the window, i.e. for n_lr epochs at the new rate.
            train_costs.append(np.inf)
        else:
            train_costs.append(train_res)

        if lr < lr_stop:
            print("Stopping optimization based on learning rate criteria.")
            break
def eval_cost_epoch(self, datasets, kind='train', ext_input_extxi=None,
                    batch_size=None):
    """Evaluate the mean per-batch costs over one epoch, with dropout off.

    Args:
      datasets: A dict of data dicts.  The dataset dict is simply a
        name(string)-> data dictionary mapping (See top of lfads.py).
      kind (optional): 'train' or 'valid'; forwarded to run_epoch.
      ext_input_extxi (optional): Accepted for interface compatibility; not
        read by this method.
      batch_size (optional): Accepted for interface compatibility; not read
        by this method.

    Returns:
      a 3 tuple of costs:
        (epoch total cost, epoch reconstruction cost, epoch KL cost)
    """
    fetches = [self.cost, self.recon_cost, self.kl_cost]
    # keep_prob=1.0: evaluation runs without dropout.
    per_batch_values = self.run_epoch(datasets, fetches, kind=kind,
                                      keep_prob=1.0)

    nbatches = len(per_batch_values)
    cost_sum = recon_sum = kl_sum = 0.0
    for batch_values in per_batch_values:
        cost_sum += batch_values[0]
        recon_sum += batch_values[1]
        kl_sum += batch_values[2]

    return cost_sum / nbatches, recon_sum / nbatches, kl_sum / nbatches
def eval_model_runs_batch(self, data_name, data_bxtxd, ext_input_bxtxi=None,
                          do_eval_cost=False, do_average_batch=False):
    """Returns all the goodies for the entire model, per batch.

    data_bxtxd (and ext_input_bxtxi, when given) may hold fewer than
    hps.batch_size examples along the first dimension.  In that case the
    inputs are zero-padded up to batch_size before the session call and the
    per-example outputs are truncated back to the original count afterwards.
    The values stored under 'costs', 'nll_bound_vaes', 'nll_bound_iwaes' and
    'train_steps' are not truncated, i.e. they are returned as evaluated on
    the padded batch.

    Args:
      data_name: The name of the data dict, to select which in/out matrices
        to use.
      data_bxtxd: Numpy array training data with shape:
        batch_size x # time steps x # dimensions
      ext_input_bxtxi: Numpy array training external input with shape:
        batch_size x # time steps x # external input dims
      do_eval_cost (optional): If true, the IWAE (Importance Weighted
        Autoencoder) log likelihood bound, instead of the VAE version.
        NOTE(review): this flag is not read anywhere in this function body;
        both bounds are always evaluated and returned.
      do_average_batch (optional): average over the batch, useful for getting
        good IWAE costs, and model outputs for a single data point.

    Returns:
      A dictionary with the outputs of the model decoder, with keys:
      'gen_ics', 'gen_states', 'factors', 'output_dist_params', 'costs',
      'nll_bound_vaes', 'nll_bound_iwaes', 'train_steps', plus
      'prior_g0_mean', 'prior_g0_logvar', 'post_g0_mean', 'post_g0_logvar'
      when hps.ic_dim > 0, and 'controller_outputs' when hps.co_dim > 0.
    """
    session = tf.get_default_session()

    # If fewer than batch_size examples were provided, pad the first
    # dimension with zeros up to batch_size.
    hps = self.hps
    batch_size = hps.batch_size
    E, _, _ = data_bxtxd.shape
    if E < hps.batch_size:
        data_bxtxd = np.pad(data_bxtxd, ((0, hps.batch_size-E), (0, 0), (0, 0)),
                            mode='constant', constant_values=0)
        if ext_input_bxtxi is not None:
            ext_input_bxtxi = np.pad(ext_input_bxtxi,
                                     ((0, hps.batch_size-E), (0, 0), (0, 0)),
                                     mode='constant', constant_values=0)

    # keep_prob=1.0: no dropout while evaluating.
    feed_dict = self.build_feed_dict(data_name, data_bxtxd,
                                     ext_input_bxtxi, keep_prob=1.0)

    # The order of entries in tf_vals must match the order in which the
    # results are read back below through fidxs[ff].
    tf_vals = [self.gen_ics, self.gen_states, self.factors,
               self.output_dist_params]
    tf_vals.append(self.cost)
    tf_vals.append(self.nll_bound_vae)
    tf_vals.append(self.nll_bound_iwae)
    tf_vals.append(self.train_step)  # the step counter, not the train op
    if self.hps.ic_dim > 0:
        tf_vals += [self.prior_zs_g0.mean, self.prior_zs_g0.logvar,
                    self.posterior_zs_g0.mean, self.posterior_zs_g0.logvar]
    if self.hps.co_dim > 0:
        tf_vals.append(self.controller_outputs)
    # flatten is defined elsewhere; presumably it returns a flat list of
    # tensors plus, per entry of tf_vals, the positions of its elements in
    # that flat list -- TODO confirm.
    tf_vals_flat, fidxs = flatten(tf_vals)

    np_vals_flat = session.run(tf_vals_flat, feed_dict=feed_dict)

    # Regroup the flat results, one list per entry of tf_vals, in order.
    ff = 0
    gen_ics = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
    gen_states = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
    factors = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
    out_dist_params = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
    costs = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
    nll_bound_vaes = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
    nll_bound_iwaes = [np_vals_flat[f] for f in fidxs[ff]]; ff +=1
    train_steps = [np_vals_flat[f] for f in fidxs[ff]]; ff +=1
    if self.hps.ic_dim > 0:
        prior_g0_mean = [np_vals_flat[f] for f in fidxs[ff]]; ff +=1
        prior_g0_logvar = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
        post_g0_mean = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
        post_g0_logvar = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
    if self.hps.co_dim > 0:
        controller_outputs = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1

    # These entries are single-element lists; take the element itself.
    gen_ics = gen_ics[0]
    costs = costs[0]
    nll_bound_vaes = nll_bound_vaes[0]
    nll_bound_iwaes = nll_bound_iwaes[0]
    train_steps = train_steps[0]

    # Going by the helper's name, these per-time-step lists of batch x dim
    # arrays are converted to batch x time x dim arrays -- TODO confirm.
    gen_states = list_t_bxn_to_tensor_bxtxn(gen_states)
    factors = list_t_bxn_to_tensor_bxtxn(factors)
    out_dist_params = list_t_bxn_to_tensor_bxtxn(out_dist_params)
    if self.hps.ic_dim > 0:
        prior_g0_mean = prior_g0_mean[0]
        prior_g0_logvar = prior_g0_logvar[0]
        post_g0_mean = post_g0_mean[0]
        post_g0_logvar = post_g0_logvar[0]
    if self.hps.co_dim > 0:
        controller_outputs = list_t_bxn_to_tensor_bxtxn(controller_outputs)

    # Drop the rows that correspond to the zero padding added above.
    if E < hps.batch_size:
        idx = np.arange(E)
        gen_ics = gen_ics[idx, :]
        gen_states = gen_states[idx, :]
        factors = factors[idx, :, :]
        out_dist_params = out_dist_params[idx, :, :]
        if self.hps.ic_dim > 0:
            prior_g0_mean = prior_g0_mean[idx, :]
            prior_g0_logvar = prior_g0_logvar[idx, :]
            post_g0_mean = post_g0_mean[idx, :]
            post_g0_logvar = post_g0_logvar[idx, :]
        if self.hps.co_dim > 0:
            controller_outputs = controller_outputs[idx, :, :]

    # Optionally collapse the batch dimension by averaging over it.
    if do_average_batch:
        gen_ics = np.mean(gen_ics, axis=0)
        gen_states = np.mean(gen_states, axis=0)
        factors = np.mean(factors, axis=0)
        out_dist_params = np.mean(out_dist_params, axis=0)
        if self.hps.ic_dim > 0:
            prior_g0_mean = np.mean(prior_g0_mean, axis=0)
            prior_g0_logvar = np.mean(prior_g0_logvar, axis=0)
            post_g0_mean = np.mean(post_g0_mean, axis=0)
            post_g0_logvar = np.mean(post_g0_logvar, axis=0)
        if self.hps.co_dim > 0:
            controller_outputs = np.mean(controller_outputs, axis=0)

    model_vals = {}
    model_vals['gen_ics'] = gen_ics
    model_vals['gen_states'] = gen_states
    model_vals['factors'] = factors
    model_vals['output_dist_params'] = out_dist_params
    model_vals['costs'] = costs
    model_vals['nll_bound_vaes'] = nll_bound_vaes
    model_vals['nll_bound_iwaes'] = nll_bound_iwaes
    model_vals['train_steps'] = train_steps
    if self.hps.ic_dim > 0:
        model_vals['prior_g0_mean'] = prior_g0_mean
        model_vals['prior_g0_logvar'] = prior_g0_logvar
        model_vals['post_g0_mean'] = post_g0_mean
        model_vals['post_g0_logvar'] = post_g0_logvar
    if self.hps.co_dim > 0:
        model_vals['controller_outputs'] = controller_outputs

    return model_vals
def eval_model_runs_avg_epoch(self, data_name, data_extxd,
                              ext_input_extxi=None):
    """Returns the expected value of the model outputs for each example.

    The expected value is taken over hidden (z) variables, namely the initial
    conditions and the control inputs.  The expected value is approximate,
    and accomplished by filling a batch with batch_size copies of every
    example and averaging the model outputs over that batch.  At most
    hps.ps_nexamples_to_process examples are processed.

    Args:
      data_name: The name of the data dict, to select which in/out matrices
        to use.
      data_extxd: Numpy array training data with shape:
        # examples x # time steps x # dimensions
      ext_input_extxi (optional): Numpy array training external input with
        shape: # examples x # time steps x # external input dims

    Returns:
      A dictionary with the averaged outputs of the model decoder, with one
      row per processed example, namely: prior g0 mean and log-variance,
      approx. posterior mean and log-variance (all only when ic_dim > 0),
      the generator initial conditions, the control inputs (if enabled), the
      state of the generator, the factors, the output distribution parameters,
      e.g. (rates or mean and variances), and the costs, NLL bounds and
      train steps.
    """
    hps = self.hps
    nsamples = hps.batch_size
    E, T, D = data_extxd.shape
    E_to_process = min(hps.ps_nexamples_to_process, E)

    # Width of the output distribution parameters per time step.
    if hps.output_dist == 'poisson':
        nparams = D
    elif hps.output_dist == 'gaussian':
        nparams = D + D
    else:
        assert False, "NIY"

    # Pre-allocate every output buffer under the key it is returned with.
    model_runs = {}
    if hps.ic_dim > 0:
        for key in ('prior_g0_mean', 'prior_g0_logvar',
                    'post_g0_mean', 'post_g0_logvar'):
            model_runs[key] = np.zeros([E_to_process, hps.ic_dim])
    model_runs['gen_ics'] = np.zeros([E_to_process, hps.gen_dim])
    if hps.co_dim > 0:
        model_runs['controller_outputs'] = \
            np.zeros([E_to_process, T, hps.co_dim])
    model_runs['gen_states'] = np.zeros([E_to_process, T, hps.gen_dim])
    model_runs['factors'] = np.zeros([E_to_process, T, hps.factors_dim])
    model_runs['output_dist_params'] = np.zeros([E_to_process, T, nparams])
    for key in ('costs', 'nll_bound_vaes', 'nll_bound_iwaes', 'train_steps'):
        model_runs[key] = np.zeros(E_to_process)

    for es_idx in range(E_to_process):
        print("Running %d of %d." % (es_idx+1, E_to_process))
        # A batch made of nsamples copies of the same example.
        example_idxs = np.full(nsamples, es_idx, dtype=np.int32)
        data_bxtxd, ext_input_bxtxi = self.get_batch(
            data_extxd, ext_input_extxi,
            batch_size=nsamples, example_idxs=example_idxs)
        model_values = self.eval_model_runs_batch(data_name, data_bxtxd,
                                                  ext_input_bxtxi,
                                                  do_eval_cost=True,
                                                  do_average_batch=True)

        # Store the batch-averaged values in row es_idx of each buffer.
        for key, buf in model_runs.items():
            buf[es_idx] = model_values[key]

        print('bound nll(vae): %.3f, bound nll(iwae): %.3f' \
              % (model_runs['nll_bound_vaes'][es_idx],
                 model_runs['nll_bound_iwaes'][es_idx]))

    return model_runs
def eval_model_runs_push_mean(self, data_name, data_extxd,
                              ext_input_extxi=None):
    """Returns values of interest for the model by pushing the means through

    The mean values for both initial conditions and the control inputs are
    pushed through the model instead of sampling (as is done in
    eval_model_runs_avg_epoch).
    This is a quick and approximate version of estimating these values
    instead of sampling from the posterior many times and then averaging
    those values of interest.

    Internally, up to batch_size distinct trials are run through the model
    at once.  At most hps.ps_nexamples_to_process trials are processed.

    Args:
      data_name: The name of the data dict, to select which in/out matrices
        to use.
      data_extxd: Numpy array training data with shape:
        # examples x # time steps x # dimensions
      ext_input_extxi (optional): Numpy array training external input with
        shape: # examples x # time steps x # external input dims

    Returns:
      A dictionary with the estimated outputs of the model decoder, with one
      row per processed trial, namely: prior g0 mean and log-variance,
      approx. posterior mean and log-variance (all only when ic_dim > 0),
      the generator initial conditions, the control inputs (if enabled), the
      state of the generator, the factors, the output distribution parameters,
      e.g. (rates or mean and variances), and the costs, NLL bounds and
      train steps.
    """
    hps = self.hps
    batch_size = hps.batch_size
    E, T, D = data_extxd.shape
    E_to_process = hps.ps_nexamples_to_process
    if E_to_process > E:
        print("Setting number of posterior samples to process to : ", E)
        E_to_process = E

    # Width of the output distribution parameters per time step.
    if hps.output_dist == 'poisson':
        nparams = D
    elif hps.output_dist == 'gaussian':
        nparams = D + D
    else:
        assert False, "NIY"

    # Pre-allocate every output buffer under the key it is returned with.
    model_runs = {}
    if hps.ic_dim > 0:
        for key in ('prior_g0_mean', 'prior_g0_logvar',
                    'post_g0_mean', 'post_g0_logvar'):
            model_runs[key] = np.zeros([E_to_process, hps.ic_dim])
    model_runs['gen_ics'] = np.zeros([E_to_process, hps.gen_dim])
    if hps.co_dim > 0:
        model_runs['controller_outputs'] = \
            np.zeros([E_to_process, T, hps.co_dim])
    model_runs['gen_states'] = np.zeros([E_to_process, T, hps.gen_dim])
    model_runs['factors'] = np.zeros([E_to_process, T, hps.factors_dim])
    model_runs['output_dist_params'] = np.zeros([E_to_process, T, nparams])
    for key in ('costs', 'nll_bound_vaes', 'nll_bound_iwaes', 'train_steps'):
        model_runs[key] = np.zeros(E_to_process)

    # Walk over the trials in consecutive chunks of at most batch_size; the
    # last chunk may be shorter.
    for batch_idx, start in enumerate(range(0, E_to_process, batch_size)):
        stop = min(start + batch_size, E_to_process)
        es_idx = np.arange(start, stop, dtype=np.int32)
        print("Running trial batch %d with %d trials" % (batch_idx+1,
                                                         len(es_idx)))
        data_bxtxd, ext_input_bxtxi = self.get_batch(
            data_extxd, ext_input_extxi,
            batch_size=batch_size, example_idxs=es_idx)
        model_values = self.eval_model_runs_batch(data_name, data_bxtxd,
                                                  ext_input_bxtxi,
                                                  do_eval_cost=True,
                                                  do_average_batch=False)

        # Store per-trial values in the rows of this chunk; values that come
        # back as scalars are broadcast to every trial in the chunk.
        for key, buf in model_runs.items():
            buf[es_idx] = model_values[key]

    return model_runs
def write_model_runs(self, datasets, output_fname=None, push_mean=False):
    """Evaluate the model on every dataset and save the computed values.

    The 'train' and 'valid' splits of each dataset are evaluated separately,
    and each resulting dictionary is written (gzip compressed) to its own
    file inside hps.lfads_save_dir.  The file name has the form
    <stem><dataset name>_<train|valid>_<hps.kind>.

    LFADS generates a number of outputs for each example, and these are all
    saved.  They are:
      The mean and variance of the prior of g0.
      The mean and variance of approximate posterior of g0.
      The control inputs (if enabled)
      The initial conditions, g0, for all examples.
      The generator states for all time.
      The factors for all time.
      The output distribution parameters (e.g. rates) for all time.

    Args:
      datasets: a dictionary of named data_dictionaries, see top of lfads.py.
        Every data dictionary must provide the keys 'train_data',
        'train_ext_input', 'valid_data' and 'valid_ext_input'.
      output_fname: a file name stem for the output files.  When it is None
        or empty the stem "model_runs_" is used.
      push_mean: if False (default), generates batch_size samples for each
        trial and averages the results.  If True, runs each trial once
        without noise, pushing the posterior mean initial conditions and
        control inputs through the trained model.  False is used for
        posterior_sample_and_average, True is used for posterior_push_mean.
    """
    hps = self.hps
    kind = hps.kind
    stem = output_fname if output_fname else "model_runs_"

    for data_name, data_dict in datasets.items():
        # Look up both splits before evaluating anything, so a missing key is
        # reported before any work is done for this dataset.
        splits = [(split,
                   data_dict[split + '_data'],
                   data_dict[split + '_ext_input'])
                  for split in ('train', 'valid')]

        for data_kind, data_extxd, ext_input_extxi in splits:
            fname = '_'.join([stem + data_name, data_kind, kind])
            print("Writing data for %s data and kind %s." % (data_name, data_kind))

            evaluate = (self.eval_model_runs_push_mean if push_mean
                        else self.eval_model_runs_avg_epoch)
            model_runs = evaluate(data_name, data_extxd, ext_input_extxi)

            write_data(os.path.join(hps.lfads_save_dir, fname), model_runs,
                       compression='gzip')
            print("Done.")
def write_model_samples(self, dataset_name, output_fname=None):
    """Use the prior distribution to generate batch_size number of samples
    from the model.

    LFADS generates a number of outputs for each sample, and these are all
    saved.  They are:
      The mean and log-variance of the prior of g0 (if hps.ic_dim > 0).
      The control inputs sampled from the prior (if hps.co_dim > 0).
      The initial conditions, g0, for all examples.
      The generator states for all time.
      The factors for all time.
      The output distribution parameters (e.g. rates) for all time.
      The cost, stored as a one-element array.

    Args:
      dataset_name: The name of the dataset to grab the factors -> rates
        alignment matrices from.
      output_fname: The name of the file in which to save the generated
        samples, relative to hps.lfads_save_dir.  When None or empty, the
        name "model_samples_<dataset_name>_<hps.kind>" is used.
    """
    hps = self.hps
    batch_size = hps.batch_size

    # Resolve the destination before running the session: os.path.join
    # rejects None, so the default argument used to raise a TypeError only
    # after all the samples had already been computed.
    if not output_fname:
        output_fname = "model_samples_" + dataset_name + '_' + hps.kind
    full_fname = os.path.join(hps.lfads_save_dir, output_fname)

    print("Generating %d samples" % (batch_size))
    tf_vals = [self.factors, self.gen_states, self.gen_ics,
               self.cost, self.output_dist_params]
    if hps.ic_dim > 0:
        tf_vals += [self.prior_zs_g0.mean, self.prior_zs_g0.logvar]
    if hps.co_dim > 0:
        tf_vals += [self.prior_zs_ar_con.samples_t]
    # fidxs[k] holds the positions, within the flat list, of the values that
    # belong to tf_vals[k].
    tf_vals_flat, fidxs = flatten(tf_vals)

    session = tf.get_default_session()
    feed_dict = {self.dataName: dataset_name,
                 self.keep_prob: 1.0}  # keep_prob of 1.0: no dropout
    np_vals_flat = session.run(tf_vals_flat, feed_dict=feed_dict)

    def unflatten(group):
        # Gather the fetched values that belong to tf_vals[group].
        return [np_vals_flat[f] for f in fidxs[group]]

    # Groups 0-4 are always present, in the order tf_vals was built.
    # Per-time-step lists are stacked with list_t_bxn_to_tensor_bxtxn;
    # single-valued groups are unwrapped from their one-element list.
    factors = list_t_bxn_to_tensor_bxtxn(unflatten(0))
    gen_states = list_t_bxn_to_tensor_bxtxn(unflatten(1))
    gen_ics = unflatten(2)[0]
    costs = unflatten(3)[0]
    output_dist_params = list_t_bxn_to_tensor_bxtxn(unflatten(4))

    model_vals = {}
    model_vals['gen_ics'] = gen_ics
    model_vals['gen_states'] = gen_states
    model_vals['factors'] = factors
    model_vals['output_dist_params'] = output_dist_params
    model_vals['costs'] = costs.reshape(1)

    # The optional groups follow, in the same order they were appended above.
    group = 5
    if hps.ic_dim > 0:
        model_vals['prior_g0_mean'] = unflatten(group)[0]
        model_vals['prior_g0_logvar'] = unflatten(group + 1)[0]
        group += 2
    if hps.co_dim > 0:
        model_vals['prior_zs_ar_con'] = list_t_bxn_to_tensor_bxtxn(
            unflatten(group))

    write_data(full_fname, model_vals, compression='gzip')
    print("Done.")
@staticmethod
def eval_model_parameters(use_nested=True, include_strs=None):
    """Evaluate and return all of the TF variables in the model.

    Args:
      use_nested (optional): For returning values, use a nested dictionary,
        based on variable scoping, or return all variables in a flat
        dictionary.
      include_strs (optional): A string, or a list of strings, used together
        with the built-in pattern "LFADS" as a filter on variable names.  A
        variable is returned when its name contains "LFADS" or at least one
        of these strings as a sub-string.

    Returns:
      The parameters of the model.  This can be in a flat dictionary keyed
      by the full variable name, or a nested dictionary, where the nesting
      is by variable scope (the name split on '/').
    """
    all_tf_vars = tf.global_variables()
    session = tf.get_default_session()
    all_tf_vars_eval = session.run(all_tf_vars)

    strs = ["LFADS"]
    if include_strs:
        # A bare string is one pattern, not a sequence of single characters.
        if isinstance(include_strs, str):
            include_strs = [include_strs]
        strs.extend(include_strs)

    vars_dict = {}
    for var, var_eval in zip(all_tf_vars, all_tf_vars_eval):
        # Keep the variable if any pattern occurs in its name.  The previous
        # test looped over the characters of var.name and looked them up in
        # include_strs, which ignored `strs` and raised a TypeError when
        # include_strs was left at its default of None.
        if not any(s in var.name for s in strs):
            continue
        if isinstance(var_eval, np.ndarray):
            e = var_eval
        else:
            # Non-array values are wrapped so every saved value is an ndarray.
            print(var.name,
                  " is not numpy array, saving as numpy array\nwith value: ",
                  var_eval, type(var_eval))
            e = np.array(var_eval)
            print(e, type(e))
        vars_dict[var.name] = e

    if not use_nested:
        return vars_dict

    # Rebuild the scope hierarchy: every '/'-separated component except the
    # last becomes a nested dictionary, the last one holds the value.
    nested_vars_dict = {}
    for var_name, value in vars_dict.items():
        parts = var_name.split('/')
        current_dict = nested_vars_dict
        for part in parts[:-1]:
            current_dict = current_dict.setdefault(part, {})
        current_dict[parts[-1]] = value

    return nested_vars_dict
@staticmethod |
def spikify_rates(rates_bxtxd): |
"""Randomly spikify underlying rates according a Poisson distribution |
Args: |
rates_bxtxd: a numpy tensor with shape: |
Returns: |
A numpy array with the same shape as rates_bxtxd, but with the event |
counts. |
""" |
B,T,N = rates_bxtxd.shape |
assert all([B > 0, N > 0]), "problems" |
spikes_bxtxd = np.zeros([B,T,N], dtype=np.int32) |
for b in range(B): |
for t in range(T): |
for n in range(N): |
rate = rates_bxtxd[b,t,n] |
count = np.random.poisson(rate) |
spikes_bxtxd[b,t,n] = count |
return spikes_bxtxd |