""" |
|
LFADS - Latent Factor Analysis via Dynamical Systems. |
|
|
|
LFADS is an unsupervised method to decompose time series data into |
|
various factors, such as an initial condition, a generative |
|
dynamical system, control inputs to that generator, and a low |
|
dimensional description of the observed data, called the factors. |
|
Additionally, the observations have a noise model (in this case |
|
Poisson), so a denoised version of the observations is also created |
|
(e.g. underlying rates of a Poisson distribution given the observed |
|
event counts). |
|
|
|
The main data structure being passed around is a dataset. This is a dictionary |
|
of data dictionaries. |
|
|
|
DATASET: The top level dictionary is simply name (string -> dictionary). |
|
The nested dictionary is the DATA DICTIONARY, which has the following keys: |
|
'train_data' and 'valid_data', whose values are the corresponding training |
|
and validation data with shape |
|
ExTxD, E - # examples, T - # time steps, D - # dimensions in data. |
|
The data dictionary also has a few more keys: |
|
'train_ext_input' and 'valid_ext_input', if there are known external inputs
|
to the system being modeled, these take on dimensions: |
|
ExTxI, E - # examples, T - # time steps, I = # dimensions in input. |
|
'alignment_matrix_cxf' - If you are using data from multiple days, it may be
possible to align the channels (see manuscript).  If so, each dataset will
|
contain this matrix, which will be used for both the input adapter and the |
|
output adapter for each dataset. These matrices, if provided, must be of |
|
size [data_dim x factors] where data_dim is the number of neurons recorded |
|
on that day, and factors is chosen and set through the '--factors' flag. |
|
'alignment_bias_c' - See alignment_matrix_cxf.  This bias will be used as
the offset for the alignment transformation.  The bias is *subtracted* from
the data, so PCA-style initializations can align factors across sessions.
|
|
|
|
|
If one runs LFADS on data where the true rates are known for some trials
(say simulated test data, as in the example shipped with the paper), then
one can add three more fields for plotting purposes: 'train_truth',
'valid_truth', and 'conversion_factor'.  The truth fields have the same
dimensions as 'train_data' and 'valid_data', but contain the underlying
rates of the observations.  Finally, if one needs to rescale the true
underlying firing rates for plotting, there is the 'conversion_factor' key.
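
For concreteness, a minimal, purely illustrative dataset for a single
session might be built like this (names and shapes here are hypothetical;
the alignment keys described above are optional and omitted):

  import numpy as np
  E, T, D = 100, 50, 30   # examples, time steps, data dimensions
  datasets = {
      'session0': {
          'train_data': np.random.poisson(1.0, (E, T, D)).astype(np.int32),
          'valid_data': np.random.poisson(1.0, (20, T, D)).astype(np.int32),
          'train_ext_input': None,   # or ExTxI float arrays, if present
          'valid_ext_input': None,
      },
  }

Note the integer dtype: the default Poisson output distribution asserts that
the observations are integer counts.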
|
""" |
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
|
|
import numpy as np |
|
import os |
|
import tensorflow as tf |
|
from distributions import LearnableDiagonalGaussian, DiagonalGaussianFromInput |
|
from distributions import diag_gaussian_log_likelihood |
|
from distributions import KLCost_GaussianGaussian, Poisson |
|
from distributions import LearnableAutoRegressive1Prior |
|
from distributions import KLCost_GaussianGaussianProcessSampled |
|
|
|
from utils import init_linear, linear, list_t_bxn_to_tensor_bxtxn, write_data |
|
from utils import log_sum_exp, flatten |
|
from plot_lfads import plot_lfads |
|
|
|
|
|
class GRU(object): |
|
"""Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078). |
|
|
|
""" |
|
def __init__(self, num_units, forget_bias=1.0, weight_scale=1.0, |
|
clip_value=np.inf, collections=None): |
|
"""Create a GRU object. |
|
|
|
Args: |
|
num_units: Number of units in the GRU |
|
forget_bias (optional): Hack to help learning. |
|
weight_scale (optional): weights are scaled by ws/sqrt(#inputs), with |
|
ws being the weight scale. |
|
clip_value (optional): if the recurrent values grow above this value, |
|
clip them. |
|
      collections (optional): List of additional collections the variables
        should belong to.
|
""" |
|
self._num_units = num_units |
|
self._forget_bias = forget_bias |
|
self._weight_scale = weight_scale |
|
self._clip_value = clip_value |
|
self._collections = collections |
|
|
|
@property |
|
def state_size(self): |
|
return self._num_units |
|
|
|
@property |
|
def output_size(self): |
|
return self._num_units |
|
|
|
@property |
|
def state_multiplier(self): |
|
return 1 |
|
|
|
def output_from_state(self, state): |
|
"""Return the output portion of the state.""" |
|
return state |
|
|
|
def __call__(self, inputs, state, scope=None): |
|
"""Gated recurrent unit (GRU) function. |
|
|
|
Args: |
|
inputs: A 2D batch x input_dim tensor of inputs. |
|
state: The previous state from the last time step. |
|
scope (optional): TF variable scope for defined GRU variables. |
|
|
|
Returns: |
|
A tuple (state, state), where state is the newly computed state at time t. |
|
It is returned twice to respect an interface that works for LSTMs. |
|
""" |
|
|
|
x = inputs |
|
h = state |
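    # Sketch of the computation below (a standard GRU update):
    #   r = sigmoid(W_ru [x, h])          (reset gate)
    #   u = sigmoid(W_ru [x, h] + b_f)    (update gate, b_f = forget_bias)
    #   c = tanh(W_c [x, r * h])          (candidate state)
    #   h_new = u * h + (1 - u) * c, then clipped to [-clip_value, clip_value].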
|
if inputs is not None: |
|
xh = tf.concat(axis=1, values=[x, h]) |
|
else: |
|
xh = h |
|
|
|
with tf.variable_scope(scope or type(self).__name__): |
|
with tf.variable_scope("Gates"): |
|
|
|
r, u = tf.split(axis=1, num_or_size_splits=2, value=linear(xh, |
|
2 * self._num_units, |
|
alpha=self._weight_scale, |
|
name="xh_2_ru", |
|
collections=self._collections)) |
|
r, u = tf.sigmoid(r), tf.sigmoid(u + self._forget_bias) |
|
with tf.variable_scope("Candidate"): |
|
xrh = tf.concat(axis=1, values=[x, r * h]) |
|
c = tf.tanh(linear(xrh, self._num_units, name="xrh_2_c", |
|
collections=self._collections)) |
|
new_h = u * h + (1 - u) * c |
|
new_h = tf.clip_by_value(new_h, -self._clip_value, self._clip_value) |
|
|
|
return new_h, new_h |
|
|
|
|
|
class GenGRU(object): |
|
"""Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078). |
|
|
|
  This version is specialized for the generator, but it isn't as fast, so we
  keep both.  Note that this version allows L2 regularization on the
  recurrent weights separately from the input weights; it also implicitly
  rescales the inputs to larger magnitude (via the 1/sqrt(#inputs) scaling in
  the linear helper routine) when there are fewer inputs than recurrent state
  dimensions.
|
|
|
""" |
|
def __init__(self, num_units, forget_bias=1.0, |
|
input_weight_scale=1.0, rec_weight_scale=1.0, clip_value=np.inf, |
|
input_collections=None, recurrent_collections=None): |
|
"""Create a GRU object. |
|
|
|
Args: |
|
num_units: Number of units in the GRU |
|
forget_bias (optional): Hack to help learning. |
|
input_weight_scale (optional): weights are scaled ws/sqrt(#inputs), with |
|
ws being the weight scale. |
|
rec_weight_scale (optional): weights are scaled ws/sqrt(#inputs), |
|
with ws being the weight scale. |
|
clip_value (optional): if the recurrent values grow above this value, |
|
clip them. |
|
      input_collections (optional): List of additional collections the
        input->rec weights should belong to.
      recurrent_collections (optional): List of additional collections the
        rec->rec weights should belong to.
|
""" |
|
self._num_units = num_units |
|
self._forget_bias = forget_bias |
|
self._input_weight_scale = input_weight_scale |
|
self._rec_weight_scale = rec_weight_scale |
|
self._clip_value = clip_value |
|
self._input_collections = input_collections |
|
self._rec_collections = recurrent_collections |
|
|
|
@property |
|
def state_size(self): |
|
return self._num_units |
|
|
|
@property |
|
def output_size(self): |
|
return self._num_units |
|
|
|
@property |
|
def state_multiplier(self): |
|
return 1 |
|
|
|
def output_from_state(self, state): |
|
"""Return the output portion of the state.""" |
|
return state |
|
|
|
def __call__(self, inputs, state, scope=None): |
|
"""Gated recurrent unit (GRU) function. |
|
|
|
Args: |
|
inputs: A 2D batch x input_dim tensor of inputs. |
|
state: The previous state from the last time step. |
|
scope (optional): TF variable scope for defined GRU variables. |
|
|
|
Returns: |
|
A tuple (state, state), where state is the newly computed state at time t. |
|
It is returned twice to respect an interface that works for LSTMs. |
|
""" |
|
|
|
x = inputs |
|
h = state |
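    # Same update as the GRU class above, but the input->hidden ("x_2_ru",
    # "x_2_c") and hidden->hidden ("h_2_ru", "rh_2_c") weights are built as
    # separate linear maps, so the recurrent weights can be placed in their
    # own collections (used for L2 regularization) independently of the input
    # weights, and the two paths can use different weight scales.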
|
with tf.variable_scope(scope or type(self).__name__): |
|
with tf.variable_scope("Gates"): |
|
|
|
r_x = u_x = 0.0 |
|
if x is not None: |
|
r_x, u_x = tf.split(axis=1, num_or_size_splits=2, value=linear(x, |
|
2 * self._num_units, |
|
alpha=self._input_weight_scale, |
|
do_bias=False, |
|
name="x_2_ru", |
|
normalized=False, |
|
collections=self._input_collections)) |
|
|
|
r_h, u_h = tf.split(axis=1, num_or_size_splits=2, value=linear(h, |
|
2 * self._num_units, |
|
do_bias=True, |
|
alpha=self._rec_weight_scale, |
|
name="h_2_ru", |
|
collections=self._rec_collections)) |
|
r = r_x + r_h |
|
u = u_x + u_h |
|
r, u = tf.sigmoid(r), tf.sigmoid(u + self._forget_bias) |
|
|
|
with tf.variable_scope("Candidate"): |
|
c_x = 0.0 |
|
if x is not None: |
|
c_x = linear(x, self._num_units, name="x_2_c", do_bias=False, |
|
alpha=self._input_weight_scale, |
|
normalized=False, |
|
collections=self._input_collections) |
|
c_rh = linear(r*h, self._num_units, name="rh_2_c", do_bias=True, |
|
alpha=self._rec_weight_scale, |
|
collections=self._rec_collections) |
|
c = tf.tanh(c_x + c_rh) |
|
|
|
new_h = u * h + (1 - u) * c |
|
new_h = tf.clip_by_value(new_h, -self._clip_value, self._clip_value) |
|
|
|
return new_h, new_h |
|
|
|
|
|
class LFADS(object): |
|
"""LFADS - Latent Factor Analysis via Dynamical Systems. |
|
|
|
LFADS is an unsupervised method to decompose time series data into |
|
various factors, such as an initial condition, a generative |
|
dynamical system, inferred inputs to that generator, and a low |
|
dimensional description of the observed data, called the factors. |
|
  Additionally, the observations have a noise model (in this case
|
Poisson), so a denoised version of the observations is also created |
|
(e.g. underlying rates of a Poisson distribution given the observed |
|
event counts). |
|
""" |
|
|
|
def __init__(self, hps, kind="train", datasets=None): |
|
"""Create an LFADS model. |
|
|
|
    train - a model for training; sampling from the posteriors is used.
    posterior_sample_and_average - sample from the posterior; this is used
      to evaluate the expected value of the outputs of LFADS, given a
      specific input, by averaging over multiple samples from the
      approximate posterior.  Also used for the lower bound on the negative
      log-likelihood via the IWAE (Importance Weighted Autoencoder) bound.
      This is the denoising operation.
    posterior_push_mean - like posterior_sample_and_average, but pushes the
      posterior means through the model instead of sampling.
    prior_sample - a model for generation; sampling from the priors is used.
|
|
|
Args: |
|
hps: The dictionary of hyper parameters. |
|
kind: the type of model to build (see above). |
|
datasets: a dictionary of named data_dictionaries, see top of lfads.py |
|
""" |
|
print("Building graph...") |
|
all_kinds = ['train', 'posterior_sample_and_average', 'posterior_push_mean', |
|
'prior_sample'] |
|
assert kind in all_kinds, 'Wrong kind' |
|
if hps.feedback_factors_or_rates == "rates": |
|
assert len(hps.dataset_names) == 1, \ |
|
"Multiple datasets not supported for rate feedback." |
|
num_steps = hps.num_steps |
|
ic_dim = hps.ic_dim |
|
co_dim = hps.co_dim |
|
ext_input_dim = hps.ext_input_dim |
|
cell_class = GRU |
|
gen_cell_class = GenGRU |
|
|
|
def makelambda(v): |
|
return lambda: v |
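    # tf.case (used below to pick the per-dataset readin/readout weights)
    # expects callables rather than tensors.  makelambda gives each tensor its
    # own closure, avoiding Python's late-binding behavior in which every
    # lambda created inside a loop would see only the loop's final value.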
|
|
|
|
|
|
|
self.dataName = tf.placeholder(tf.string, shape=()) |
|
|
|
|
|
|
|
if hps.output_dist == 'poisson': |
|
|
|
assert np.issubdtype( |
|
datasets[hps.dataset_names[0]]['train_data'].dtype, int), \ |
|
"Data dtype must be int for poisson output distribution" |
|
data_dtype = tf.int32 |
|
    elif hps.output_dist == 'gaussian':
      assert np.issubdtype(
          datasets[hps.dataset_names[0]]['train_data'].dtype, float), \
          "Data dtype must be float for gaussian output distribution"
|
data_dtype = tf.float32 |
|
else: |
|
assert False, "NIY" |
|
self.dataset_ph = dataset_ph = tf.placeholder(data_dtype, |
|
[None, num_steps, None], |
|
name="data") |
|
self.train_step = tf.get_variable("global_step", [], tf.int64, |
|
tf.zeros_initializer(), |
|
trainable=False) |
|
self.hps = hps |
|
ndatasets = hps.ndatasets |
|
factors_dim = hps.factors_dim |
|
self.preds = preds = [None] * ndatasets |
|
self.fns_in_fac_Ws = fns_in_fac_Ws = [None] * ndatasets |
|
    self.fns_in_fac_bs = fns_in_fac_bs = [None] * ndatasets
|
self.fns_out_fac_Ws = fns_out_fac_Ws = [None] * ndatasets |
|
self.fns_out_fac_bs = fns_out_fac_bs = [None] * ndatasets |
|
self.datasetNames = dataset_names = hps.dataset_names |
|
self.ext_inputs = ext_inputs = None |
|
|
|
if len(dataset_names) == 1: |
|
if 'alignment_matrix_cxf' in datasets[dataset_names[0]].keys(): |
|
used_in_factors_dim = factors_dim |
|
in_identity_if_poss = False |
|
else: |
|
used_in_factors_dim = hps.dataset_dims[dataset_names[0]] |
|
in_identity_if_poss = True |
|
else: |
|
used_in_factors_dim = factors_dim |
|
in_identity_if_poss = False |
|
|
|
for d, name in enumerate(dataset_names): |
|
data_dim = hps.dataset_dims[name] |
|
in_mat_cxf = None |
|
in_bias_1xf = None |
|
align_bias_1xc = None |
|
|
|
if datasets and 'alignment_matrix_cxf' in datasets[name].keys(): |
|
dataset = datasets[name] |
|
if hps.do_train_readin: |
|
print("Initializing trainable readin matrix with alignment matrix" \ |
|
" provided for dataset:", name) |
|
else: |
|
print("Setting non-trainable readin matrix to alignment matrix" \ |
|
" provided for dataset:", name) |
|
in_mat_cxf = dataset['alignment_matrix_cxf'].astype(np.float32) |
|
if in_mat_cxf.shape != (data_dim, factors_dim): |
|
raise ValueError("""Alignment matrix must have dimensions %d x %d |
|
(data_dim x factors_dim), but currently has %d x %d."""% |
|
(data_dim, factors_dim, in_mat_cxf.shape[0], |
|
in_mat_cxf.shape[1])) |
|
if datasets and 'alignment_bias_c' in datasets[name].keys(): |
|
dataset = datasets[name] |
|
if hps.do_train_readin: |
|
print("Initializing trainable readin bias with alignment bias " \ |
|
"provided for dataset:", name) |
|
else: |
|
print("Setting non-trainable readin bias to alignment bias " \ |
|
"provided for dataset:", name) |
|
align_bias_c = dataset['alignment_bias_c'].astype(np.float32) |
|
align_bias_1xc = np.expand_dims(align_bias_c, axis=0) |
|
if align_bias_1xc.shape[1] != data_dim: |
|
raise ValueError("""Alignment bias must have dimensions %d |
|
(data_dim), but currently has %d."""% |
|
                           (data_dim, align_bias_1xc.shape[1]))
|
if in_mat_cxf is not None and align_bias_1xc is not None: |
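        # The readin layer computes x W + b.  To realize (x - alignment_bias) W,
        # fold the alignment bias into the linear bias as b = -alignment_bias W.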
|
|
|
|
|
|
|
in_bias_1xf = -np.dot(align_bias_1xc, in_mat_cxf) |
|
|
|
if hps.do_train_readin: |
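        # Put the readin weights in the 'IO_transformations' collection only
        # when they are trainable; that collection is what gets optimized when
        # hps.do_train_io_only is set (see the train_vars selection below).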
|
|
|
|
|
|
|
collections_readin=['IO_transformations'] |
|
else: |
|
collections_readin=None |
|
|
|
in_fac_lin = init_linear(data_dim, used_in_factors_dim, |
|
do_bias=True, |
|
mat_init_value=in_mat_cxf, |
|
bias_init_value=in_bias_1xf, |
|
identity_if_possible=in_identity_if_poss, |
|
normalized=False, name="x_2_infac_"+name, |
|
collections=collections_readin, |
|
trainable=hps.do_train_readin) |
|
in_fac_W, in_fac_b = in_fac_lin |
|
fns_in_fac_Ws[d] = makelambda(in_fac_W) |
|
fns_in_fac_bs[d] = makelambda(in_fac_b) |
|
|
|
with tf.variable_scope("glm"): |
|
out_identity_if_poss = False |
|
if len(dataset_names) == 1 and \ |
|
factors_dim == hps.dataset_dims[dataset_names[0]]: |
|
out_identity_if_poss = True |
|
for d, name in enumerate(dataset_names): |
|
data_dim = hps.dataset_dims[name] |
|
in_mat_cxf = None |
|
if datasets and 'alignment_matrix_cxf' in datasets[name].keys(): |
|
dataset = datasets[name] |
|
in_mat_cxf = dataset['alignment_matrix_cxf'].astype(np.float32) |
|
|
|
if datasets and 'alignment_bias_c' in datasets[name].keys(): |
|
dataset = datasets[name] |
|
align_bias_c = dataset['alignment_bias_c'].astype(np.float32) |
|
align_bias_1xc = np.expand_dims(align_bias_c, axis=0) |
|
|
|
out_mat_fxc = None |
|
out_bias_1xc = None |
|
if in_mat_cxf is not None: |
|
out_mat_fxc = in_mat_cxf.T |
|
if align_bias_1xc is not None: |
|
out_bias_1xc = align_bias_1xc |
|
|
|
if hps.output_dist == 'poisson': |
|
out_fac_lin = init_linear(factors_dim, data_dim, do_bias=True, |
|
mat_init_value=out_mat_fxc, |
|
bias_init_value=out_bias_1xc, |
|
identity_if_possible=out_identity_if_poss, |
|
normalized=False, |
|
name="fac_2_logrates_"+name, |
|
collections=['IO_transformations']) |
|
out_fac_W, out_fac_b = out_fac_lin |
|
|
|
elif hps.output_dist == 'gaussian': |
|
out_fac_lin_mean = \ |
|
init_linear(factors_dim, data_dim, do_bias=True, |
|
mat_init_value=out_mat_fxc, |
|
bias_init_value=out_bias_1xc, |
|
normalized=False, |
|
name="fac_2_means_"+name, |
|
collections=['IO_transformations']) |
|
out_fac_W_mean, out_fac_b_mean = out_fac_lin_mean |
|
|
|
mat_init_value = np.zeros([factors_dim, data_dim]).astype(np.float32) |
|
bias_init_value = np.ones([1, data_dim]).astype(np.float32) |
|
out_fac_lin_logvar = \ |
|
init_linear(factors_dim, data_dim, do_bias=True, |
|
mat_init_value=mat_init_value, |
|
bias_init_value=bias_init_value, |
|
normalized=False, |
|
name="fac_2_logvars_"+name, |
|
collections=['IO_transformations']) |
|
|
out_fac_W_logvar, out_fac_b_logvar = out_fac_lin_logvar |
|
out_fac_W = tf.concat( |
|
axis=1, values=[out_fac_W_mean, out_fac_W_logvar]) |
|
out_fac_b = tf.concat( |
|
axis=1, values=[out_fac_b_mean, out_fac_b_logvar]) |
|
else: |
|
assert False, "NIY" |
|
|
|
preds[d] = tf.equal(tf.constant(name), self.dataName) |
|
data_dim = hps.dataset_dims[name] |
|
fns_out_fac_Ws[d] = makelambda(out_fac_W) |
|
fns_out_fac_bs[d] = makelambda(out_fac_b) |
|
|
|
pf_pairs_in_fac_Ws = zip(preds, fns_in_fac_Ws) |
|
pf_pairs_in_fac_bs = zip(preds, fns_in_fac_bs) |
|
pf_pairs_out_fac_Ws = zip(preds, fns_out_fac_Ws) |
|
pf_pairs_out_fac_bs = zip(preds, fns_out_fac_bs) |
|
|
|
this_in_fac_W = tf.case(pf_pairs_in_fac_Ws, exclusive=True) |
|
this_in_fac_b = tf.case(pf_pairs_in_fac_bs, exclusive=True) |
|
this_out_fac_W = tf.case(pf_pairs_out_fac_Ws, exclusive=True) |
|
this_out_fac_b = tf.case(pf_pairs_out_fac_bs, exclusive=True) |
|
|
|
|
|
if hps.ext_input_dim > 0: |
|
self.ext_input = tf.placeholder(tf.float32, |
|
[None, num_steps, ext_input_dim], |
|
name="ext_input") |
|
else: |
|
self.ext_input = None |
|
ext_input_bxtxi = self.ext_input |
|
|
|
self.keep_prob = keep_prob = tf.placeholder(tf.float32, [], "keep_prob") |
|
self.batch_size = batch_size = int(hps.batch_size) |
|
self.learning_rate = tf.Variable(float(hps.learning_rate_init), |
|
trainable=False, name="learning_rate") |
|
self.learning_rate_decay_op = self.learning_rate.assign( |
|
self.learning_rate * hps.learning_rate_decay_factor) |
|
|
|
|
|
dataset_do_bxtxd = tf.nn.dropout(tf.to_float(dataset_ph), keep_prob) |
|
if hps.ext_input_dim > 0: |
|
ext_input_do_bxtxi = tf.nn.dropout(ext_input_bxtxi, keep_prob) |
|
else: |
|
ext_input_do_bxtxi = None |
|
|
|
|
|
def encode_data(dataset_bxtxd, enc_cell, name, forward_or_reverse, |
|
num_steps_to_encode): |
|
"""Encode data for LFADS |
|
Args: |
|
        dataset_bxtxd: the data to encode, a 3D tensor with dims
          batch x time x data dims.
|
enc_cell: encoder cell |
|
name: name of encoder |
|
forward_or_reverse: string, encode in forward or reverse direction |
|
num_steps_to_encode: number of steps to encode, 0:num_steps_to_encode |
|
Returns: |
|
encoded data as a list with num_steps_to_encode items, in order |
|
""" |
|
if forward_or_reverse == "forward": |
|
dstr = "_fwd" |
|
time_fwd_or_rev = range(num_steps_to_encode) |
|
else: |
|
dstr = "_rev" |
|
time_fwd_or_rev = reversed(range(num_steps_to_encode)) |
|
|
|
with tf.variable_scope(name+"_enc"+dstr, reuse=False): |
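        # The encoder's initial state is a single trainable [1 x state_size]
        # variable, tiled across the batch so every trial starts the RNN from
        # the same learned state.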
|
enc_state = tf.tile( |
|
tf.Variable(tf.zeros([1, enc_cell.state_size]), |
|
name=name+"_enc_t0"+dstr), tf.stack([batch_size, 1])) |
|
enc_state.set_shape([None, enc_cell.state_size]) |
|
|
|
enc_outs = [None] * num_steps_to_encode |
|
for i, t in enumerate(time_fwd_or_rev): |
|
with tf.variable_scope(name+"_enc"+dstr, reuse=True if i > 0 else None): |
|
dataset_t_bxd = dataset_bxtxd[:,t,:] |
|
in_fac_t_bxf = tf.matmul(dataset_t_bxd, this_in_fac_W) + this_in_fac_b |
|
in_fac_t_bxf.set_shape([None, used_in_factors_dim]) |
|
if ext_input_dim > 0 and not hps.inject_ext_input_to_gen: |
|
ext_input_t_bxi = ext_input_do_bxtxi[:,t,:] |
|
enc_input_t_bxfpe = tf.concat( |
|
axis=1, values=[in_fac_t_bxf, ext_input_t_bxi]) |
|
else: |
|
enc_input_t_bxfpe = in_fac_t_bxf |
|
enc_out, enc_state = enc_cell(enc_input_t_bxfpe, enc_state) |
|
enc_outs[t] = enc_out |
|
|
|
return enc_outs |
|
|
|
|
|
|
|
self.ic_enc_fwd = [None] * num_steps |
|
self.ic_enc_rev = [None] * num_steps |
|
if ic_dim > 0: |
|
enc_ic_cell = cell_class(hps.ic_enc_dim, |
|
weight_scale=hps.cell_weight_scale, |
|
clip_value=hps.cell_clip_value) |
|
ic_enc_fwd = encode_data(dataset_do_bxtxd, enc_ic_cell, |
|
"ic", "forward", |
|
hps.num_steps_for_gen_ic) |
|
ic_enc_rev = encode_data(dataset_do_bxtxd, enc_ic_cell, |
|
"ic", "reverse", |
|
hps.num_steps_for_gen_ic) |
|
self.ic_enc_fwd = ic_enc_fwd |
|
self.ic_enc_rev = ic_enc_rev |
|
|
|
|
|
|
|
self.ci_enc_fwd = [None] * num_steps |
|
self.ci_enc_rev = [None] * num_steps |
|
if co_dim > 0: |
|
enc_ci_cell = cell_class(hps.ci_enc_dim, |
|
weight_scale=hps.cell_weight_scale, |
|
clip_value=hps.cell_clip_value) |
|
ci_enc_fwd = encode_data(dataset_do_bxtxd, enc_ci_cell, |
|
"ci", "forward", |
|
hps.num_steps) |
|
if hps.do_causal_controller: |
|
ci_enc_rev = None |
|
else: |
|
ci_enc_rev = encode_data(dataset_do_bxtxd, enc_ci_cell, |
|
"ci", "reverse", |
|
hps.num_steps) |
|
self.ci_enc_fwd = ci_enc_fwd |
|
self.ci_enc_rev = ci_enc_rev |
|
|
|
|
|
|
|
|
|
with tf.variable_scope("z", reuse=False): |
|
self.prior_zs_g0 = None |
|
self.posterior_zs_g0 = None |
|
self.g0s_val = None |
|
if ic_dim > 0: |
|
self.prior_zs_g0 = \ |
|
LearnableDiagonalGaussian(batch_size, ic_dim, name="prior_g0", |
|
mean_init=0.0, |
|
var_min=hps.ic_prior_var_min, |
|
var_init=hps.ic_prior_var_scale, |
|
var_max=hps.ic_prior_var_max) |
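        # Read out the g0 posterior from both ends of the bidirectional IC
        # encoder: the last step of the forward pass and the time-0 output of
        # the reverse pass (each has integrated the whole encoded window).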
|
ic_enc = tf.concat(axis=1, values=[ic_enc_fwd[-1], ic_enc_rev[0]]) |
|
ic_enc = tf.nn.dropout(ic_enc, keep_prob) |
|
self.posterior_zs_g0 = \ |
|
DiagonalGaussianFromInput(ic_enc, ic_dim, "ic_enc_2_post_g0", |
|
var_min=hps.ic_post_var_min) |
|
if kind in ["train", "posterior_sample_and_average", |
|
"posterior_push_mean"]: |
|
zs_g0 = self.posterior_zs_g0 |
|
else: |
|
zs_g0 = self.prior_zs_g0 |
|
if kind in ["train", "posterior_sample_and_average", "prior_sample"]: |
|
self.g0s_val = zs_g0.sample |
|
else: |
|
self.g0s_val = zs_g0.mean |
|
|
|
|
|
self.prior_zs_co = prior_zs_co = [None] * num_steps |
|
self.posterior_zs_co = posterior_zs_co = [None] * num_steps |
|
self.zs_co = zs_co = [None] * num_steps |
|
self.prior_zs_ar_con = None |
|
if co_dim > 0: |
|
|
|
autocorrelation_taus = [hps.prior_ar_atau for x in range(hps.co_dim)] |
|
noise_variances = [hps.prior_ar_nvar for x in range(hps.co_dim)] |
|
self.prior_zs_ar_con = prior_zs_ar_con = \ |
|
LearnableAutoRegressive1Prior(batch_size, hps.co_dim, |
|
autocorrelation_taus, |
|
noise_variances, |
|
hps.do_train_prior_ar_atau, |
|
hps.do_train_prior_ar_nvar, |
|
num_steps, "u_prior_ar1") |
|
|
|
|
|
|
|
self.controller_outputs = u_t = [None] * num_steps |
|
self.con_ics = con_state = None |
|
self.con_states = con_states = [None] * num_steps |
|
self.con_outs = con_outs = [None] * num_steps |
|
self.gen_inputs = gen_inputs = [None] * num_steps |
|
if co_dim > 0: |
|
|
|
|
|
con_cell = gen_cell_class(hps.con_dim, |
|
input_weight_scale=hps.cell_weight_scale, |
|
rec_weight_scale=hps.cell_weight_scale, |
|
clip_value=hps.cell_clip_value, |
|
recurrent_collections=['l2_con_reg']) |
|
with tf.variable_scope("con", reuse=False): |
|
self.con_ics = tf.tile( |
|
tf.Variable(tf.zeros([1, hps.con_dim*con_cell.state_multiplier]), |
|
name="c0"), |
|
tf.stack([batch_size, 1])) |
|
self.con_ics.set_shape([None, con_cell.state_size]) |
|
con_states[-1] = self.con_ics |
|
|
|
gen_cell = gen_cell_class(hps.gen_dim, |
|
input_weight_scale=hps.gen_cell_input_weight_scale, |
|
rec_weight_scale=hps.gen_cell_rec_weight_scale, |
|
clip_value=hps.cell_clip_value, |
|
recurrent_collections=['l2_gen_reg']) |
|
with tf.variable_scope("gen", reuse=False): |
|
if ic_dim == 0: |
|
self.gen_ics = tf.tile( |
|
tf.Variable(tf.zeros([1, gen_cell.state_size]), name="g0"), |
|
tf.stack([batch_size, 1])) |
|
else: |
|
self.gen_ics = linear(self.g0s_val, gen_cell.state_size, |
|
identity_if_possible=True, |
|
name="g0_2_gen_ic") |
|
|
|
self.gen_states = gen_states = [None] * num_steps |
|
self.gen_outs = gen_outs = [None] * num_steps |
|
gen_states[-1] = self.gen_ics |
|
gen_outs[-1] = gen_cell.output_from_state(gen_states[-1]) |
|
self.factors = factors = [None] * num_steps |
|
factors[-1] = linear(gen_outs[-1], factors_dim, do_bias=False, |
|
normalized=True, name="gen_2_fac") |
|
|
|
self.rates = rates = [None] * num_steps |
|
|
|
with tf.variable_scope("glm", reuse=False): |
|
if hps.output_dist == 'poisson': |
|
log_rates_t0 = tf.matmul(factors[-1], this_out_fac_W) + this_out_fac_b |
|
log_rates_t0.set_shape([None, None]) |
|
rates[-1] = tf.exp(log_rates_t0) |
|
rates[-1].set_shape([None, hps.dataset_dims[hps.dataset_names[0]]]) |
|
elif hps.output_dist == 'gaussian': |
|
mean_n_logvars = tf.matmul(factors[-1],this_out_fac_W) + this_out_fac_b |
|
mean_n_logvars.set_shape([None, None]) |
|
means_t_bxd, logvars_t_bxd = tf.split(axis=1, num_or_size_splits=2, |
|
value=mean_n_logvars) |
|
rates[-1] = means_t_bxd |
|
else: |
|
assert False, "NIY" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.output_dist_params = dist_params = [None] * num_steps |
|
self.log_p_xgz_b = log_p_xgz_b = 0.0 |
|
for t in range(num_steps): |
|
|
|
if co_dim > 0: |
|
|
|
tlag = t - hps.controller_input_lag |
|
if tlag < 0: |
|
con_in_f_t = tf.zeros_like(ci_enc_fwd[0]) |
|
else: |
|
con_in_f_t = ci_enc_fwd[tlag] |
|
if hps.do_causal_controller: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
con_in_list_t = [con_in_f_t] |
|
else: |
|
tlag_rev = t + hps.controller_input_lag |
|
if tlag_rev >= num_steps: |
|
|
|
con_in_r_t = tf.zeros_like(ci_enc_rev[0]) |
|
else: |
|
con_in_r_t = ci_enc_rev[tlag_rev] |
|
con_in_list_t = [con_in_f_t, con_in_r_t] |
|
|
|
if hps.do_feed_factors_to_controller: |
|
if hps.feedback_factors_or_rates == "factors": |
|
con_in_list_t.append(factors[t-1]) |
|
elif hps.feedback_factors_or_rates == "rates": |
|
con_in_list_t.append(rates[t-1]) |
|
else: |
|
assert False, "NIY" |
|
|
|
con_in_t = tf.concat(axis=1, values=con_in_list_t) |
|
con_in_t = tf.nn.dropout(con_in_t, keep_prob) |
|
with tf.variable_scope("con", reuse=True if t > 0 else None): |
|
con_outs[t], con_states[t] = con_cell(con_in_t, con_states[t-1]) |
|
posterior_zs_co[t] = \ |
|
DiagonalGaussianFromInput(con_outs[t], co_dim, |
|
name="con_to_post_co") |
|
if kind == "train": |
|
u_t[t] = posterior_zs_co[t].sample |
|
elif kind == "posterior_sample_and_average": |
|
u_t[t] = posterior_zs_co[t].sample |
|
elif kind == "posterior_push_mean": |
|
u_t[t] = posterior_zs_co[t].mean |
|
else: |
|
u_t[t] = prior_zs_ar_con.samples_t[t] |
|
|
|
|
|
if ext_input_dim > 0 and hps.inject_ext_input_to_gen: |
|
ext_input_t_bxi = ext_input_do_bxtxi[:,t,:] |
|
if co_dim > 0: |
|
gen_inputs[t] = tf.concat(axis=1, values=[u_t[t], ext_input_t_bxi]) |
|
else: |
|
gen_inputs[t] = ext_input_t_bxi |
|
else: |
|
gen_inputs[t] = u_t[t] |
|
|
|
|
|
data_t_bxd = dataset_ph[:,t,:] |
|
with tf.variable_scope("gen", reuse=True if t > 0 else None): |
|
gen_outs[t], gen_states[t] = gen_cell(gen_inputs[t], gen_states[t-1]) |
|
gen_outs[t] = tf.nn.dropout(gen_outs[t], keep_prob) |
|
with tf.variable_scope("gen", reuse=True): |
|
factors[t] = linear(gen_outs[t], factors_dim, do_bias=False, |
|
normalized=True, name="gen_2_fac") |
|
with tf.variable_scope("glm", reuse=True if t > 0 else None): |
|
if hps.output_dist == 'poisson': |
|
log_rates_t = tf.matmul(factors[t], this_out_fac_W) + this_out_fac_b |
|
log_rates_t.set_shape([None, None]) |
|
          rates[t] = dist_params[t] = tf.exp(log_rates_t)
|
rates[t].set_shape([None, hps.dataset_dims[hps.dataset_names[0]]]) |
|
loglikelihood_t = Poisson(log_rates_t).logp(data_t_bxd) |
|
|
|
elif hps.output_dist == 'gaussian': |
|
mean_n_logvars = tf.matmul(factors[t],this_out_fac_W) + this_out_fac_b |
|
mean_n_logvars.set_shape([None, None]) |
|
means_t_bxd, logvars_t_bxd = tf.split(axis=1, num_or_size_splits=2, |
|
value=mean_n_logvars) |
|
rates[t] = means_t_bxd |
|
dist_params[t] = tf.concat( |
|
              axis=1, values=[means_t_bxd, tf.exp(logvars_t_bxd)])
|
loglikelihood_t = \ |
|
diag_gaussian_log_likelihood(data_t_bxd, |
|
means_t_bxd, logvars_t_bxd) |
|
else: |
|
assert False, "NIY" |
|
|
|
log_p_xgz_b += tf.reduce_sum(loglikelihood_t, [1]) |
|
|
|
|
|
self.corr_cost = tf.constant(0.0) |
|
if hps.co_mean_corr_scale > 0.0: |
|
all_sum_corr = [] |
|
for i in range(hps.co_dim): |
|
for j in range(i+1, hps.co_dim): |
|
sum_corr_ij = tf.constant(0.0) |
|
for t in range(num_steps): |
|
u_mean_t = posterior_zs_co[t].mean |
|
sum_corr_ij += u_mean_t[:,i]*u_mean_t[:,j] |
|
all_sum_corr.append(0.5 * tf.square(sum_corr_ij)) |
|
self.corr_cost = tf.reduce_mean(all_sum_corr) |
|
|
|
|
|
|
|
|
|
kl_cost_g0_b = tf.zeros_like(batch_size, dtype=tf.float32) |
|
kl_cost_co_b = tf.zeros_like(batch_size, dtype=tf.float32) |
|
self.kl_cost = tf.constant(0.0) |
|
self.recon_cost = tf.constant(0.0) |
|
self.nll_bound_vae = tf.constant(0.0) |
|
self.nll_bound_iwae = tf.constant(0.0) |
|
if kind in ["train", "posterior_sample_and_average", "posterior_push_mean"]: |
|
kl_cost_g0_b = 0.0 |
|
kl_cost_co_b = 0.0 |
|
if ic_dim > 0: |
|
g0_priors = [self.prior_zs_g0] |
|
g0_posts = [self.posterior_zs_g0] |
|
kl_cost_g0_b = KLCost_GaussianGaussian(g0_posts, g0_priors).kl_cost_b |
|
kl_cost_g0_b = hps.kl_ic_weight * kl_cost_g0_b |
|
if co_dim > 0: |
|
kl_cost_co_b = \ |
|
KLCost_GaussianGaussianProcessSampled( |
|
posterior_zs_co, prior_zs_ar_con).kl_cost_b |
|
kl_cost_co_b = hps.kl_co_weight * kl_cost_co_b |
|
|
|
|
|
|
|
|
|
self.recon_cost = - tf.reduce_mean(log_p_xgz_b) |
|
self.kl_cost = tf.reduce_mean(kl_cost_g0_b + kl_cost_co_b) |
|
|
|
lb_on_ll_b = log_p_xgz_b - kl_cost_g0_b - kl_cost_co_b |
|
|
|
|
|
self.nll_bound_vae = -tf.reduce_mean(lb_on_ll_b) |
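      # The IWAE bound below treats the batch as k samples of the same
      # posterior (useful when the batch is filled with copies of one trial,
      # as in posterior_sample_and_average): log(1/k * sum_i exp(lb_i)) is a
      # tighter bound on log p(x) than the per-sample average used for the
      # VAE bound above.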
|
|
|
|
|
k = tf.cast(tf.shape(log_p_xgz_b)[0], tf.float32) |
|
iwae_lb_on_ll = -tf.log(k) + log_sum_exp(lb_on_ll_b) |
|
self.nll_bound_iwae = -iwae_lb_on_ll |
|
|
|
|
|
self.l2_cost = tf.constant(0.0) |
|
if self.hps.l2_gen_scale > 0.0 or self.hps.l2_con_scale > 0.0: |
|
l2_costs = [] |
|
l2_numels = [] |
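      # Each regularized variable contributes 0.5 * scale * ||v||^2; the total
      # is divided by the number of regularized parameters below, so that
      # l2_gen_scale / l2_con_scale are roughly independent of network size.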
|
l2_reg_var_lists = [tf.get_collection('l2_gen_reg'), |
|
tf.get_collection('l2_con_reg')] |
|
l2_reg_scales = [self.hps.l2_gen_scale, self.hps.l2_con_scale] |
|
for l2_reg_vars, l2_scale in zip(l2_reg_var_lists, l2_reg_scales): |
|
for v in l2_reg_vars: |
|
numel = tf.reduce_prod(tf.concat(axis=0, values=tf.shape(v))) |
|
numel_f = tf.cast(numel, tf.float32) |
|
l2_numels.append(numel_f) |
|
v_l2 = tf.reduce_sum(v*v) |
|
l2_costs.append(0.5 * l2_scale * v_l2) |
|
self.l2_cost = tf.add_n(l2_costs) / tf.add_n(l2_numels) |
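    # Below, the KL and L2 penalties are linearly ramped from 0 to 1 over
    # kl_increase_steps / l2_increase_steps training steps (after
    # kl_start_step / l2_start_step), then held at 1.  This warm-up keeps the
    # regularizers from dominating the reconstruction term early in training;
    # the total objective is recon + ramped KL + ramped L2 + correlation cost.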
|
|
|
|
|
|
|
|
|
|
|
self.kl_decay_step = tf.maximum(self.train_step - hps.kl_start_step, 0) |
|
self.l2_decay_step = tf.maximum(self.train_step - hps.l2_start_step, 0) |
|
kl_decay_step_f = tf.cast(self.kl_decay_step, tf.float32) |
|
l2_decay_step_f = tf.cast(self.l2_decay_step, tf.float32) |
|
kl_increase_steps_f = tf.cast(hps.kl_increase_steps, tf.float32) |
|
l2_increase_steps_f = tf.cast(hps.l2_increase_steps, tf.float32) |
|
self.kl_weight = kl_weight = \ |
|
tf.minimum(kl_decay_step_f / kl_increase_steps_f, 1.0) |
|
self.l2_weight = l2_weight = \ |
|
tf.minimum(l2_decay_step_f / l2_increase_steps_f, 1.0) |
|
|
|
self.timed_kl_cost = kl_weight * self.kl_cost |
|
self.timed_l2_cost = l2_weight * self.l2_cost |
|
self.weight_corr_cost = hps.co_mean_corr_scale * self.corr_cost |
|
self.cost = self.recon_cost + self.timed_kl_cost + \ |
|
self.timed_l2_cost + self.weight_corr_cost |
|
|
|
if kind != "train": |
|
|
|
self.seso_saver = tf.train.Saver(tf.global_variables(), |
|
max_to_keep=hps.max_ckpt_to_keep) |
|
|
|
self.lve_saver = tf.train.Saver(tf.global_variables(), |
|
max_to_keep=hps.max_ckpt_to_keep_lve) |
|
|
|
return |
|
|
|
|
|
|
|
if self.hps.do_train_io_only: |
|
self.train_vars = tvars = \ |
|
tf.get_collection('IO_transformations', |
|
scope=tf.get_variable_scope().name) |
|
|
|
elif self.hps.do_train_encoder_only: |
|
tvars1 = \ |
|
tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, |
|
scope='LFADS/ic_enc_*') |
|
tvars2 = \ |
|
tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, |
|
scope='LFADS/z/ic_enc_*') |
|
|
|
self.train_vars = tvars = tvars1 + tvars2 |
|
|
|
else: |
|
self.train_vars = tvars = \ |
|
tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, |
|
scope=tf.get_variable_scope().name) |
|
print("done.") |
|
print("Model Variables (to be optimized): ") |
|
total_params = 0 |
|
for i in range(len(tvars)): |
|
shape = tvars[i].get_shape().as_list() |
|
print(" ", i, tvars[i].name, shape) |
|
total_params += np.prod(shape) |
|
print("Total model parameters: ", total_params) |
|
|
|
grads = tf.gradients(self.cost, tvars) |
|
grads, grad_global_norm = tf.clip_by_global_norm(grads, hps.max_grad_norm) |
|
opt = tf.train.AdamOptimizer(self.learning_rate, beta1=0.9, beta2=0.999, |
|
epsilon=1e-01) |
|
self.grads = grads |
|
self.grad_global_norm = grad_global_norm |
|
self.train_op = opt.apply_gradients( |
|
zip(grads, tvars), global_step=self.train_step) |
|
|
|
self.seso_saver = tf.train.Saver(tf.global_variables(), |
|
max_to_keep=hps.max_ckpt_to_keep) |
|
|
|
|
|
self.lve_saver = tf.train.Saver(tf.global_variables(), |
|
                                    max_to_keep=hps.max_ckpt_to_keep_lve)
|
|
|
|
|
|
|
self.example_image = tf.placeholder(tf.float32, shape=[1,None,None,3], |
|
name='image_tensor') |
|
self.example_summ = tf.summary.image("LFADS example", self.example_image, |
|
collections=["example_summaries"]) |
|
|
|
|
|
self.lr_summ = tf.summary.scalar("Learning rate", self.learning_rate) |
|
self.kl_weight_summ = tf.summary.scalar("KL weight", self.kl_weight) |
|
self.l2_weight_summ = tf.summary.scalar("L2 weight", self.l2_weight) |
|
self.corr_cost_summ = tf.summary.scalar("Corr cost", self.weight_corr_cost) |
|
self.grad_global_norm_summ = tf.summary.scalar("Gradient global norm", |
|
self.grad_global_norm) |
|
if hps.co_dim > 0: |
|
self.atau_summ = [None] * hps.co_dim |
|
self.pvar_summ = [None] * hps.co_dim |
|
for c in range(hps.co_dim): |
|
self.atau_summ[c] = \ |
|
tf.summary.scalar("AR Autocorrelation taus " + str(c), |
|
tf.exp(self.prior_zs_ar_con.logataus_1xu[0,c])) |
|
self.pvar_summ[c] = \ |
|
tf.summary.scalar("AR Variances " + str(c), |
|
tf.exp(self.prior_zs_ar_con.logpvars_1xu[0,c])) |
|
|
|
|
|
|
|
|
|
|
|
kl_cost_ph = tf.placeholder(tf.float32, shape=[], name='kl_cost_ph') |
|
self.kl_t_cost_summ = tf.summary.scalar("KL cost (train)", kl_cost_ph, |
|
collections=["train_summaries"]) |
|
self.kl_v_cost_summ = tf.summary.scalar("KL cost (valid)", kl_cost_ph, |
|
collections=["valid_summaries"]) |
|
l2_cost_ph = tf.placeholder(tf.float32, shape=[], name='l2_cost_ph') |
|
self.l2_cost_summ = tf.summary.scalar("L2 cost", l2_cost_ph, |
|
collections=["train_summaries"]) |
|
|
|
recon_cost_ph = tf.placeholder(tf.float32, shape=[], name='recon_cost_ph') |
|
self.recon_t_cost_summ = tf.summary.scalar("Reconstruction cost (train)", |
|
recon_cost_ph, |
|
collections=["train_summaries"]) |
|
self.recon_v_cost_summ = tf.summary.scalar("Reconstruction cost (valid)", |
|
recon_cost_ph, |
|
collections=["valid_summaries"]) |
|
|
|
total_cost_ph = tf.placeholder(tf.float32, shape=[], name='total_cost_ph') |
|
self.cost_t_summ = tf.summary.scalar("Total cost (train)", total_cost_ph, |
|
collections=["train_summaries"]) |
|
self.cost_v_summ = tf.summary.scalar("Total cost (valid)", total_cost_ph, |
|
collections=["valid_summaries"]) |
|
|
|
self.kl_cost_ph = kl_cost_ph |
|
self.l2_cost_ph = l2_cost_ph |
|
self.recon_cost_ph = recon_cost_ph |
|
self.total_cost_ph = total_cost_ph |
|
|
|
|
|
self.merged_examples = tf.summary.merge_all(key="example_summaries") |
|
self.merged_generic = tf.summary.merge_all() |
|
self.merged_train = tf.summary.merge_all(key="train_summaries") |
|
self.merged_valid = tf.summary.merge_all(key="valid_summaries") |
|
|
|
session = tf.get_default_session() |
|
self.logfile = os.path.join(hps.lfads_save_dir, "lfads_log") |
|
self.writer = tf.summary.FileWriter(self.logfile) |
|
|
|
def build_feed_dict(self, train_name, data_bxtxd, ext_input_bxtxi=None, |
|
keep_prob=None): |
|
"""Build the feed dictionary, handles cases where there is no value defined. |
|
|
|
Args: |
|
train_name: The key into the datasets, to set the tf.case statement for |
|
the proper readin / readout matrices. |
|
data_bxtxd: The data tensor |
|
ext_input_bxtxi (optional): The external input tensor |
|
keep_prob: The drop out keep probability. |
|
|
|
Returns: |
|
The feed dictionary with TF tensors as keys and data as values, for use |
|
with tf.Session.run() |
|
|
|
""" |
|
feed_dict = {} |
|
B, T, _ = data_bxtxd.shape |
|
feed_dict[self.dataName] = train_name |
|
feed_dict[self.dataset_ph] = data_bxtxd |
|
|
|
if self.ext_input is not None and ext_input_bxtxi is not None: |
|
feed_dict[self.ext_input] = ext_input_bxtxi |
|
|
|
if keep_prob is None: |
|
feed_dict[self.keep_prob] = self.hps.keep_prob |
|
else: |
|
feed_dict[self.keep_prob] = keep_prob |
|
|
|
return feed_dict |
|
|
|
@staticmethod |
|
def get_batch(data_extxd, ext_input_extxi=None, batch_size=None, |
|
example_idxs=None): |
|
"""Get a batch of data, either randomly chosen, or specified directly. |
|
|
|
Args: |
|
data_extxd: The data to model, numpy tensors with shape: |
|
# examples x # time steps x # dimensions |
|
ext_input_extxi (optional): The external inputs, numpy tensor with shape: |
|
# examples x # time steps x # external input dimensions |
|
batch_size: The size of the batch to return |
|
example_idxs (optional): The example indices used to select examples. |
|
|
|
Returns: |
|
A tuple with two parts: |
|
1. Batched data numpy tensor with shape: |
|
batch_size x # time steps x # dimensions |
|
2. Batched external input numpy tensor with shape: |
|
batch_size x # time steps x # external input dims |
|
""" |
|
assert batch_size is not None or example_idxs is not None, "Problems" |
|
E, T, D = data_extxd.shape |
|
if example_idxs is None: |
|
example_idxs = np.random.choice(E, batch_size) |
|
|
|
ext_input_bxtxi = None |
|
if ext_input_extxi is not None: |
|
ext_input_bxtxi = ext_input_extxi[example_idxs,:,:] |
|
|
|
return data_extxd[example_idxs,:,:], ext_input_bxtxi |
|
|
|
@staticmethod |
|
def example_idxs_mod_batch_size(nexamples, batch_size): |
|
"""Given a number of examples, E, and a batch_size, B, generate indices |
|
[0, 1, 2, ... B-1; |
|
[B, B+1, ... 2*B-1; |
|
... |
|
] |
|
returning those indices as a 2-dim tensor shaped like E/B x B. Note that |
|
shape is only correct if E % B == 0. If not, then an extra row is generated |
|
    so that the remainder of examples is included.  The extra examples used
    to fill out that row are drawn randomly (without replacement) from the
    full set and sorted; see randomize_example_idxs_mod_batch_size for fully
    randomized behavior.
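
    For example (purely illustrative), nexamples=5 and batch_size=2 would
    give

      [[0, 1],
       [2, 3],
       [4, r]],  bmrem=1

    where r is a randomly chosen index in range(5).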
|
|
|
Args: |
|
nexamples: The number of examples to batch up. |
|
batch_size: The size of the batch. |
|
Returns: |
|
2-dim tensor as described above. |
|
""" |
|
bmrem = batch_size - (nexamples % batch_size) |
|
bmrem_examples = [] |
|
if bmrem < batch_size: |
|
|
|
ridxs = np.random.permutation(nexamples)[0:bmrem].astype(np.int32) |
|
bmrem_examples = np.sort(ridxs) |
|
    example_idxs = list(range(nexamples)) + list(bmrem_examples)
|
example_idxs_e_x_edivb = np.reshape(example_idxs, [-1, batch_size]) |
|
return example_idxs_e_x_edivb, bmrem |
|
|
|
@staticmethod |
|
def randomize_example_idxs_mod_batch_size(nexamples, batch_size): |
|
"""Indices 1:nexamples, randomized, in 2D form of |
|
shape = (nexamples / batch_size) x batch_size. The remainder |
|
is managed by drawing randomly from 1:nexamples. |
|
|
|
Args: |
|
nexamples: number of examples to randomize |
|
batch_size: number of elements in batch |
|
|
|
Returns: |
|
      The randomized, properly shaped indices.
|
""" |
|
assert nexamples > batch_size, "Problems" |
|
bmrem = batch_size - nexamples % batch_size |
|
bmrem_examples = [] |
|
if bmrem < batch_size: |
|
bmrem_examples = np.random.choice(range(nexamples), |
|
size=bmrem, replace=False) |
|
    example_idxs = list(range(nexamples)) + list(bmrem_examples)
|
mixed_example_idxs = np.random.permutation(example_idxs) |
|
example_idxs_e_x_edivb = np.reshape(mixed_example_idxs, [-1, batch_size]) |
|
return example_idxs_e_x_edivb, bmrem |
|
|
|
def shuffle_spikes_in_time(self, data_bxtxd): |
|
"""Shuffle the spikes in the temporal dimension. This is useful to |
|
help the LFADS system avoid overfitting to individual spikes or fast |
|
oscillations found in the data that are irrelevant to behavior. A |
|
pure 'tabula rasa' approach would avoid this, but LFADS is sensitive |
|
enough to pick up dynamics that you may not want. |
|
|
|
Args: |
|
data_bxtxd: numpy array of spike count data to be shuffled. |
|
Returns: |
|
S_bxtxd, a numpy array with the same dimensions and contents as |
|
data_bxtxd, but shuffled appropriately. |
|
|
|
""" |
|
|
|
B, T, N = data_bxtxd.shape |
|
w = self.hps.temporal_spike_jitter_width |
|
|
|
if w == 0: |
|
return data_bxtxd |
|
|
|
max_counts = np.max(data_bxtxd) |
|
S_bxtxd = np.zeros([B,T,N]) |
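    # Decompose the counts into unit events: for each threshold mc, every bin
    # with count >= mc contributes one event, which is jittered by a uniform
    # integer offset in [-w, w) and reflected back inside [0, T-1].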
|
|
|
|
|
|
|
for mc in range(1,max_counts+1): |
|
idxs = np.nonzero(data_bxtxd >= mc) |
|
|
|
data_ones = np.zeros_like(data_bxtxd) |
|
data_ones[data_bxtxd >= mc] = 1 |
|
|
|
nfound = len(idxs[0]) |
|
shuffles_incrs_in_time = np.random.randint(-w, w, size=nfound) |
|
|
|
shuffle_tidxs = idxs[1].copy() |
|
shuffle_tidxs += shuffles_incrs_in_time |
|
|
|
|
|
shuffle_tidxs[shuffle_tidxs < 0] = -shuffle_tidxs[shuffle_tidxs < 0] |
|
shuffle_tidxs[shuffle_tidxs > T-1] = \ |
|
(T-1)-(shuffle_tidxs[shuffle_tidxs > T-1] -(T-1)) |
|
|
|
for iii in zip(idxs[0], shuffle_tidxs, idxs[2]): |
|
S_bxtxd[iii] += 1 |
|
|
|
return S_bxtxd |
|
|
|
def shuffle_and_flatten_datasets(self, datasets, kind='train'): |
|
"""Since LFADS supports multiple datasets in the same dynamical model, |
|
we have to be careful to use all the data in a single training epoch. But |
|
    since the datasets may have different data dimensionality, we cannot batch
|
examples from data dictionaries together. Instead, we generate random |
|
batches within each data dictionary, and then randomize these batches |
|
while holding onto the dataname, so that when it's time to feed |
|
the graph, the correct in/out matrices can be selected, per batch. |
|
|
|
Args: |
|
datasets: A dict of data dicts. The dataset dict is simply a |
|
name(string)-> data dictionary mapping (See top of lfads.py). |
|
kind: 'train' or 'valid' |
|
|
|
Returns: |
|
A flat list, in which each element is a pair ('name', indices). |
|
""" |
|
batch_size = self.hps.batch_size |
|
ndatasets = len(datasets) |
|
random_example_idxs = {} |
|
epoch_idxs = {} |
|
all_name_example_idx_pairs = [] |
|
kind_data = kind + '_data' |
|
for name, data_dict in datasets.items(): |
|
nexamples, ntime, data_dim = data_dict[kind_data].shape |
|
epoch_idxs[name] = 0 |
|
random_example_idxs, _ = \ |
|
self.randomize_example_idxs_mod_batch_size(nexamples, batch_size) |
|
|
|
epoch_size = random_example_idxs.shape[0] |
|
names = [name] * epoch_size |
|
all_name_example_idx_pairs += zip(names, random_example_idxs) |
|
|
|
np.random.shuffle(all_name_example_idx_pairs) |
|
|
|
return all_name_example_idx_pairs |
|
|
|
def train_epoch(self, datasets, batch_size=None, do_save_ckpt=True): |
|
"""Train the model through the entire dataset once. |
|
|
|
Args: |
|
datasets: A dict of data dicts. The dataset dict is simply a |
|
name(string)-> data dictionary mapping (See top of lfads.py). |
|
batch_size (optional): The batch_size to use |
|
do_save_ckpt (optional): Should the routine save a checkpoint on this |
|
training epoch? |
|
|
|
Returns: |
|
A tuple with 6 float values: |
|
(total cost of the epoch, epoch reconstruction cost, |
|
epoch kl cost, KL weight used this training epoch, |
|
total l2 cost on generator, and the corresponding weight). |
|
""" |
|
ops_to_eval = [self.cost, self.recon_cost, |
|
self.kl_cost, self.kl_weight, |
|
self.l2_cost, self.l2_weight, |
|
self.train_op] |
|
collected_op_values = self.run_epoch(datasets, ops_to_eval, kind="train") |
|
|
|
total_cost = total_recon_cost = total_kl_cost = 0.0 |
|
|
|
epoch_size = len(collected_op_values) |
|
for op_values in collected_op_values: |
|
total_cost += op_values[0] |
|
total_recon_cost += op_values[1] |
|
total_kl_cost += op_values[2] |
|
|
|
kl_weight = collected_op_values[-1][3] |
|
l2_cost = collected_op_values[-1][4] |
|
l2_weight = collected_op_values[-1][5] |
|
|
|
epoch_total_cost = total_cost / epoch_size |
|
epoch_recon_cost = total_recon_cost / epoch_size |
|
epoch_kl_cost = total_kl_cost / epoch_size |
|
|
|
if do_save_ckpt: |
|
session = tf.get_default_session() |
|
checkpoint_path = os.path.join(self.hps.lfads_save_dir, |
|
self.hps.checkpoint_name + '.ckpt') |
|
self.seso_saver.save(session, checkpoint_path, |
|
global_step=self.train_step) |
|
|
|
return epoch_total_cost, epoch_recon_cost, epoch_kl_cost, \ |
|
kl_weight, l2_cost, l2_weight |
|
|
|
|
|
def run_epoch(self, datasets, ops_to_eval, kind="train", batch_size=None, |
|
do_collect=True, keep_prob=None): |
|
"""Run the model through the entire dataset once. |
|
|
|
Args: |
|
datasets: A dict of data dicts. The dataset dict is simply a |
|
name(string)-> data dictionary mapping (See top of lfads.py). |
|
ops_to_eval: A list of tensorflow operations that will be evaluated in |
|
the tf.session.run() call. |
|
batch_size (optional): The batch_size to use |
|
do_collect (optional): Should the routine collect all session.run |
|
output as a list, and return it? |
|
keep_prob (optional): The dropout keep probability. |
|
|
|
Returns: |
|
A list of lists, the internal list is the return for the ops for each |
|
session.run() call. The outer list collects over the epoch. |
|
""" |
|
hps = self.hps |
|
all_name_example_idx_pairs = \ |
|
self.shuffle_and_flatten_datasets(datasets, kind) |
|
|
|
kind_data = kind + '_data' |
|
kind_ext_input = kind + '_ext_input' |
|
|
|
total_cost = total_recon_cost = total_kl_cost = 0.0 |
|
session = tf.get_default_session() |
|
epoch_size = len(all_name_example_idx_pairs) |
|
evaled_ops_list = [] |
|
for name, example_idxs in all_name_example_idx_pairs: |
|
data_dict = datasets[name] |
|
data_extxd = data_dict[kind_data] |
|
if hps.output_dist == 'poisson' and hps.temporal_spike_jitter_width > 0: |
|
data_extxd = self.shuffle_spikes_in_time(data_extxd) |
|
|
|
ext_input_extxi = data_dict[kind_ext_input] |
|
data_bxtxd, ext_input_bxtxi = self.get_batch(data_extxd, ext_input_extxi, |
|
example_idxs=example_idxs) |
|
|
|
feed_dict = self.build_feed_dict(name, data_bxtxd, ext_input_bxtxi, |
|
keep_prob=keep_prob) |
|
evaled_ops_np = session.run(ops_to_eval, feed_dict=feed_dict) |
|
if do_collect: |
|
evaled_ops_list.append(evaled_ops_np) |
|
|
|
return evaled_ops_list |
|
|
|
def summarize_all(self, datasets, summary_values): |
|
"""Plot and summarize stuff in tensorboard. |
|
|
|
Note that everything done in the current function is otherwise done on |
|
a single, randomly selected dataset (except for summary_values, which are |
|
passed in.) |
|
|
|
Args: |
|
      datasets: the dictionary of datasets used in the study.
|
summary_values: These summary values are created from the training loop, |
|
and so summarize the entire set of datasets. |
|
""" |
|
hps = self.hps |
|
tr_kl_cost = summary_values['tr_kl_cost'] |
|
tr_recon_cost = summary_values['tr_recon_cost'] |
|
tr_total_cost = summary_values['tr_total_cost'] |
|
kl_weight = summary_values['kl_weight'] |
|
l2_weight = summary_values['l2_weight'] |
|
l2_cost = summary_values['l2_cost'] |
|
has_any_valid_set = summary_values['has_any_valid_set'] |
|
i = summary_values['nepochs'] |
|
|
|
session = tf.get_default_session() |
|
train_summ, train_step = session.run([self.merged_train, |
|
self.train_step], |
|
feed_dict={self.l2_cost_ph:l2_cost, |
|
self.kl_cost_ph:tr_kl_cost, |
|
self.recon_cost_ph:tr_recon_cost, |
|
self.total_cost_ph:tr_total_cost}) |
|
self.writer.add_summary(train_summ, train_step) |
|
if has_any_valid_set: |
|
ev_kl_cost = summary_values['ev_kl_cost'] |
|
ev_recon_cost = summary_values['ev_recon_cost'] |
|
ev_total_cost = summary_values['ev_total_cost'] |
|
eval_summ = session.run(self.merged_valid, |
|
feed_dict={self.kl_cost_ph:ev_kl_cost, |
|
self.recon_cost_ph:ev_recon_cost, |
|
self.total_cost_ph:ev_total_cost}) |
|
self.writer.add_summary(eval_summ, train_step) |
|
print("Epoch:%d, step:%d (TRAIN, VALID): total: %.2f, %.2f\ |
|
recon: %.2f, %.2f, kl: %.2f, %.2f, l2: %.5f,\ |
|
kl weight: %.2f, l2 weight: %.2f" % \ |
|
(i, train_step, tr_total_cost, ev_total_cost, |
|
tr_recon_cost, ev_recon_cost, tr_kl_cost, ev_kl_cost, |
|
l2_cost, kl_weight, l2_weight)) |
|
|
|
csv_outstr = "epoch,%d, step,%d, total,%.2f,%.2f, \ |
|
recon,%.2f,%.2f, kl,%.2f,%.2f, l2,%.5f, \ |
|
klweight,%.2f, l2weight,%.2f\n"% \ |
|
(i, train_step, tr_total_cost, ev_total_cost, |
|
tr_recon_cost, ev_recon_cost, tr_kl_cost, ev_kl_cost, |
|
l2_cost, kl_weight, l2_weight) |
|
|
|
else: |
|
print("Epoch:%d, step:%d TRAIN: total: %.2f recon: %.2f, kl: %.2f,\ |
|
l2: %.5f, kl weight: %.2f, l2 weight: %.2f" % \ |
|
(i, train_step, tr_total_cost, tr_recon_cost, tr_kl_cost, |
|
l2_cost, kl_weight, l2_weight)) |
|
csv_outstr = "epoch,%d, step,%d, total,%.2f, recon,%.2f, kl,%.2f, \ |
|
l2,%.5f, klweight,%.2f, l2weight,%.2f\n"% \ |
|
(i, train_step, tr_total_cost, tr_recon_cost, |
|
tr_kl_cost, l2_cost, kl_weight, l2_weight) |
|
|
|
if self.hps.csv_log: |
|
csv_file = os.path.join(self.hps.lfads_save_dir, self.hps.csv_log+'.csv') |
|
with open(csv_file, "a") as myfile: |
|
myfile.write(csv_outstr) |
|
|
|
|
|
def plot_single_example(self, datasets): |
|
"""Plot an image relating to a randomly chosen, specific example. We use |
|
    posterior sample and average: take one example, fill a whole batch with
    copies of that example, sample from the posterior, and then average the
    resulting quantities.
|
|
|
""" |
|
hps = self.hps |
|
    all_data_names = list(datasets.keys())
|
data_name = np.random.permutation(all_data_names)[0] |
|
data_dict = datasets[data_name] |
|
has_valid_set = True if data_dict['valid_data'] is not None else False |
|
cf = 1.0 |
|
|
|
|
|
E, _, _ = data_dict['train_data'].shape |
|
eidx = np.random.choice(E) |
|
example_idxs = eidx * np.ones(hps.batch_size, dtype=np.int32) |
|
|
|
train_data_bxtxd, train_ext_input_bxtxi = \ |
|
self.get_batch(data_dict['train_data'], data_dict['train_ext_input'], |
|
example_idxs=example_idxs) |
|
|
|
truth_train_data_bxtxd = None |
|
if 'train_truth' in data_dict and data_dict['train_truth'] is not None: |
|
truth_train_data_bxtxd, _ = self.get_batch(data_dict['train_truth'], |
|
example_idxs=example_idxs) |
|
cf = data_dict['conversion_factor'] |
|
|
|
|
|
train_model_values = self.eval_model_runs_batch(data_name, |
|
train_data_bxtxd, |
|
train_ext_input_bxtxi, |
|
do_average_batch=False) |
|
|
|
train_step = train_model_values['train_steps'] |
|
feed_dict = self.build_feed_dict(data_name, train_data_bxtxd, |
|
train_ext_input_bxtxi, keep_prob=1.0) |
|
|
|
session = tf.get_default_session() |
|
generic_summ = session.run(self.merged_generic, feed_dict=feed_dict) |
|
self.writer.add_summary(generic_summ, train_step) |
|
|
|
valid_data_bxtxd = valid_model_values = valid_ext_input_bxtxi = None |
|
truth_valid_data_bxtxd = None |
|
if has_valid_set: |
|
E, _, _ = data_dict['valid_data'].shape |
|
eidx = np.random.choice(E) |
|
example_idxs = eidx * np.ones(hps.batch_size, dtype=np.int32) |
|
valid_data_bxtxd, valid_ext_input_bxtxi = \ |
|
self.get_batch(data_dict['valid_data'], |
|
data_dict['valid_ext_input'], |
|
example_idxs=example_idxs) |
|
if 'valid_truth' in data_dict and data_dict['valid_truth'] is not None: |
|
truth_valid_data_bxtxd, _ = self.get_batch(data_dict['valid_truth'], |
|
example_idxs=example_idxs) |
|
else: |
|
truth_valid_data_bxtxd = None |
|
|
|
|
|
valid_model_values = self.eval_model_runs_batch(data_name, |
|
valid_data_bxtxd, |
|
valid_ext_input_bxtxi, |
|
do_average_batch=False) |
|
|
|
example_image = plot_lfads(train_bxtxd=train_data_bxtxd, |
|
train_model_vals=train_model_values, |
|
train_ext_input_bxtxi=train_ext_input_bxtxi, |
|
train_truth_bxtxd=truth_train_data_bxtxd, |
|
valid_bxtxd=valid_data_bxtxd, |
|
valid_model_vals=valid_model_values, |
|
valid_ext_input_bxtxi=valid_ext_input_bxtxi, |
|
valid_truth_bxtxd=truth_valid_data_bxtxd, |
|
bidx=None, cf=cf, output_dist=hps.output_dist) |
|
example_image = np.expand_dims(example_image, axis=0) |
|
example_summ = session.run(self.merged_examples, |
|
feed_dict={self.example_image : example_image}) |
|
self.writer.add_summary(example_summ) |
|
|
|
def train_model(self, datasets): |
|
"""Train the model, print per-epoch information, and save checkpoints. |
|
|
|
Loop over training epochs. The function that actually does the |
|
training is train_epoch. This function iterates over the training |
|
data, one epoch at a time. The learning rate schedule is such |
|
that it will stay the same until the cost goes up in comparison to |
|
the last few values, then it will drop. |
|
|
|
Args: |
|
datasets: A dict of data dicts. The dataset dict is simply a |
|
name(string)-> data dictionary mapping (See top of lfads.py). |
|
""" |
|
hps = self.hps |
|
has_any_valid_set = False |
|
for data_dict in datasets.values(): |
|
if data_dict['valid_data'] is not None: |
|
has_any_valid_set = True |
|
break |
|
|
|
session = tf.get_default_session() |
|
lr = session.run(self.learning_rate) |
|
lr_stop = hps.learning_rate_stop |
|
i = -1 |
|
train_costs = [] |
|
valid_costs = [] |
|
ev_total_cost = ev_recon_cost = ev_kl_cost = 0.0 |
|
lowest_ev_cost = np.Inf |
|
while True: |
|
i += 1 |
|
      do_save_ckpt = (i % 10 == 0)
|
tr_total_cost, tr_recon_cost, tr_kl_cost, kl_weight, l2_cost, l2_weight = \ |
|
self.train_epoch(datasets, do_save_ckpt=do_save_ckpt) |
|
|
|
|
|
|
|
|
|
if has_any_valid_set: |
|
ev_total_cost, ev_recon_cost, ev_kl_cost = \ |
|
self.eval_cost_epoch(datasets, kind='valid') |
|
valid_costs.append(ev_total_cost) |
|
|
|
|
|
|
|
n_lve = 1 |
|
run_avg_lve = np.mean(valid_costs[-n_lve:]) |
|
|
|
|
|
|
|
|
|
|
|
|
|
if kl_weight >= 1.0 and \ |
|
(l2_weight >= 1.0 or \ |
|
(self.hps.l2_gen_scale == 0.0 and self.hps.l2_con_scale == 0.0)) \ |
|
and (len(valid_costs) > n_lve and run_avg_lve < lowest_ev_cost): |
|
|
|
lowest_ev_cost = run_avg_lve |
|
checkpoint_path = os.path.join(self.hps.lfads_save_dir, |
|
self.hps.checkpoint_name + '_lve.ckpt') |
|
self.lve_saver.save(session, checkpoint_path, |
|
global_step=self.train_step, |
|
latest_filename='checkpoint_lve') |
|
|
|
|
|
values = {'nepochs':i, 'has_any_valid_set': has_any_valid_set, |
|
'tr_total_cost':tr_total_cost, 'ev_total_cost':ev_total_cost, |
|
'tr_recon_cost':tr_recon_cost, 'ev_recon_cost':ev_recon_cost, |
|
'tr_kl_cost':tr_kl_cost, 'ev_kl_cost':ev_kl_cost, |
|
'l2_weight':l2_weight, 'kl_weight':kl_weight, |
|
'l2_cost':l2_cost} |
|
self.summarize_all(datasets, values) |
|
self.plot_single_example(datasets) |
|
|
|
|
|
train_res = tr_total_cost |
|
n_lr = hps.learning_rate_n_to_compare |
|
if len(train_costs) > n_lr and train_res > np.max(train_costs[-n_lr:]): |
|
_ = session.run(self.learning_rate_decay_op) |
|
lr = session.run(self.learning_rate) |
|
print(" Decreasing learning rate to %f." % lr) |
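        # Append inf (rather than the true cost) so that at least n_lr more
        # epochs elapse before the learning rate can be decayed again.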
|
|
|
train_costs.append(np.inf) |
|
else: |
|
train_costs.append(train_res) |
|
|
|
if lr < lr_stop: |
|
print("Stopping optimization based on learning rate criteria.") |
|
break |
|
|
|
def eval_cost_epoch(self, datasets, kind='train', ext_input_extxi=None, |
|
batch_size=None): |
|
"""Evaluate the cost of the epoch. |
|
|
|
Args: |
|
data_dict: The dictionary of data (training and validation) used for |
|
training and evaluation of the model, respectively. |
|
|
|
Returns: |
|
a 3 tuple of costs: |
|
(epoch total cost, epoch reconstruction cost, epoch KL cost) |
|
""" |
|
ops_to_eval = [self.cost, self.recon_cost, self.kl_cost] |
|
collected_op_values = self.run_epoch(datasets, ops_to_eval, kind=kind, |
|
keep_prob=1.0) |
|
|
|
total_cost = total_recon_cost = total_kl_cost = 0.0 |
|
|
|
epoch_size = len(collected_op_values) |
|
for op_values in collected_op_values: |
|
total_cost += op_values[0] |
|
total_recon_cost += op_values[1] |
|
total_kl_cost += op_values[2] |
|
|
|
epoch_total_cost = total_cost / epoch_size |
|
epoch_recon_cost = total_recon_cost / epoch_size |
|
epoch_kl_cost = total_kl_cost / epoch_size |
|
|
|
return epoch_total_cost, epoch_recon_cost, epoch_kl_cost |
|
|
|
def eval_model_runs_batch(self, data_name, data_bxtxd, ext_input_bxtxi=None, |
|
do_eval_cost=False, do_average_batch=False): |
|
"""Returns all the goodies for the entire model, per batch. |
|
|
|
    data_bxtxd and ext_input_bxtxi may have fewer than batch_size examples
    along the first dimension; in that case this routine pads and truncates
    automatically.
|
|
|
Args: |
|
data_name: The name of the data dict, to select which in/out matrices |
|
to use. |
|
data_bxtxd: Numpy array training data with shape: |
|
batch_size x # time steps x # dimensions |
|
ext_input_bxtxi: Numpy array training external input with shape: |
|
batch_size x # time steps x # external input dims |
|
      do_eval_cost (optional): If true, evaluate the IWAE (Importance Weighted
        Autoencoder) log likelihood bound, instead of the VAE version.
|
do_average_batch (optional): average over the batch, useful for getting |
|
good IWAE costs, and model outputs for a single data point. |
|
|
|
Returns: |
|
A dictionary with the outputs of the model decoder, namely: |
|
        prior g0 mean, prior g0 variance, approx. posterior mean, approx.
        posterior variance, the generator initial conditions, the control inputs (if
|
enabled), the state of the generator, the factors, and the rates. |
|
""" |
|
session = tf.get_default_session() |
|
|
|
|
|
hps = self.hps |
|
batch_size = hps.batch_size |
|
E, _, _ = data_bxtxd.shape |
|
if E < hps.batch_size: |
|
data_bxtxd = np.pad(data_bxtxd, ((0, hps.batch_size-E), (0, 0), (0, 0)), |
|
mode='constant', constant_values=0) |
|
if ext_input_bxtxi is not None: |
|
ext_input_bxtxi = np.pad(ext_input_bxtxi, |
|
((0, hps.batch_size-E), (0, 0), (0, 0)), |
|
mode='constant', constant_values=0) |
|
|
|
feed_dict = self.build_feed_dict(data_name, data_bxtxd, |
|
ext_input_bxtxi, keep_prob=1.0) |
|
|
|
|
|
|
|
tf_vals = [self.gen_ics, self.gen_states, self.factors, |
|
self.output_dist_params] |
|
tf_vals.append(self.cost) |
|
tf_vals.append(self.nll_bound_vae) |
|
tf_vals.append(self.nll_bound_iwae) |
|
tf_vals.append(self.train_step) |
|
if self.hps.ic_dim > 0: |
|
tf_vals += [self.prior_zs_g0.mean, self.prior_zs_g0.logvar, |
|
self.posterior_zs_g0.mean, self.posterior_zs_g0.logvar] |
|
if self.hps.co_dim > 0: |
|
tf_vals.append(self.controller_outputs) |
|
tf_vals_flat, fidxs = flatten(tf_vals) |
|
|
|
np_vals_flat = session.run(tf_vals_flat, feed_dict=feed_dict) |
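    # fidxs[k] holds the positions in np_vals_flat belonging to the k-th entry
    # of tf_vals, so each comprehension below recovers one output as a list
    # (over time steps, or of length 1 for non-temporal values).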
|
|
|
ff = 0 |
|
gen_ics = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1 |
|
gen_states = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1 |
|
factors = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1 |
|
out_dist_params = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1 |
|
costs = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1 |
|
nll_bound_vaes = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1 |
|
nll_bound_iwaes = [np_vals_flat[f] for f in fidxs[ff]]; ff +=1 |
|
train_steps = [np_vals_flat[f] for f in fidxs[ff]]; ff +=1 |
|
if self.hps.ic_dim > 0: |
|
prior_g0_mean = [np_vals_flat[f] for f in fidxs[ff]]; ff +=1 |
|
prior_g0_logvar = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1 |
|
post_g0_mean = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1 |
|
post_g0_logvar = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1 |
|
if self.hps.co_dim > 0: |
|
controller_outputs = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1 |
|
|
|
|
|
gen_ics = gen_ics[0] |
|
costs = costs[0] |
|
nll_bound_vaes = nll_bound_vaes[0] |
|
nll_bound_iwaes = nll_bound_iwaes[0] |
|
train_steps = train_steps[0] |
|
|
|
|
|
gen_states = list_t_bxn_to_tensor_bxtxn(gen_states) |
|
factors = list_t_bxn_to_tensor_bxtxn(factors) |
|
out_dist_params = list_t_bxn_to_tensor_bxtxn(out_dist_params) |
|
if self.hps.ic_dim > 0: |
|
|
|
prior_g0_mean = prior_g0_mean[0] |
|
prior_g0_logvar = prior_g0_logvar[0] |
|
post_g0_mean = post_g0_mean[0] |
|
post_g0_logvar = post_g0_logvar[0] |
|
if self.hps.co_dim > 0: |
|
controller_outputs = list_t_bxn_to_tensor_bxtxn(controller_outputs) |
|
|
|
|
|
if E < hps.batch_size: |
|
idx = np.arange(E) |
|
gen_ics = gen_ics[idx, :] |
|
gen_states = gen_states[idx, :] |
|
factors = factors[idx, :, :] |
|
out_dist_params = out_dist_params[idx, :, :] |
|
if self.hps.ic_dim > 0: |
|
prior_g0_mean = prior_g0_mean[idx, :] |
|
prior_g0_logvar = prior_g0_logvar[idx, :] |
|
post_g0_mean = post_g0_mean[idx, :] |
|
post_g0_logvar = post_g0_logvar[idx, :] |
|
if self.hps.co_dim > 0: |
|
controller_outputs = controller_outputs[idx, :, :] |
|
|
|
if do_average_batch: |
|
gen_ics = np.mean(gen_ics, axis=0) |
|
gen_states = np.mean(gen_states, axis=0) |
|
factors = np.mean(factors, axis=0) |
|
out_dist_params = np.mean(out_dist_params, axis=0) |
|
if self.hps.ic_dim > 0: |
|
prior_g0_mean = np.mean(prior_g0_mean, axis=0) |
|
prior_g0_logvar = np.mean(prior_g0_logvar, axis=0) |
|
post_g0_mean = np.mean(post_g0_mean, axis=0) |
|
post_g0_logvar = np.mean(post_g0_logvar, axis=0) |
|
if self.hps.co_dim > 0: |
|
controller_outputs = np.mean(controller_outputs, axis=0) |
|
|
|
model_vals = {} |
|
model_vals['gen_ics'] = gen_ics |
|
model_vals['gen_states'] = gen_states |
|
model_vals['factors'] = factors |
|
model_vals['output_dist_params'] = out_dist_params |
|
model_vals['costs'] = costs |
|
model_vals['nll_bound_vaes'] = nll_bound_vaes |
|
model_vals['nll_bound_iwaes'] = nll_bound_iwaes |
|
model_vals['train_steps'] = train_steps |
|
if self.hps.ic_dim > 0: |
|
model_vals['prior_g0_mean'] = prior_g0_mean |
|
model_vals['prior_g0_logvar'] = prior_g0_logvar |
|
model_vals['post_g0_mean'] = post_g0_mean |
|
model_vals['post_g0_logvar'] = post_g0_logvar |
|
if self.hps.co_dim > 0: |
|
model_vals['controller_outputs'] = controller_outputs |
|
|
|
return model_vals |
|
|
|
def eval_model_runs_avg_epoch(self, data_name, data_extxd, |
|
ext_input_extxi=None): |
|
"""Returns all the expected value for goodies for the entire model. |
|
|
|
The expected value is taken over hidden (z) variables, namely the initial |
|
conditions and the control inputs. The expected value is approximate, and |
|
accomplished via sampling (batch_size) samples for every examples. |
|
|
|
Args: |
|
data_name: The name of the data dict, to select which in/out matrices |
|
to use. |
|
data_extxd: Numpy array training data with shape: |
|
# examples x # time steps x # dimensions |
|
ext_input_extxi (optional): Numpy array training external input with |
|
shape: # examples x # time steps x # external input dims |
|
|
|
Returns: |
|
A dictionary with the averaged outputs of the model decoder, namely: |
|
        prior g0 mean, prior g0 variance, approx. posterior mean, approx.
        posterior variance, the generator initial conditions, the control inputs (if
|
enabled), the state of the generator, the factors, and the output |
|
distribution parameters, e.g. (rates or mean and variances). |
|
""" |
|
hps = self.hps |
|
batch_size = hps.batch_size |
|
E, T, D = data_extxd.shape |
|
E_to_process = hps.ps_nexamples_to_process |
|
if E_to_process > E: |
|
E_to_process = E |
|
|
|
if hps.ic_dim > 0: |
|
prior_g0_mean = np.zeros([E_to_process, hps.ic_dim]) |
|
prior_g0_logvar = np.zeros([E_to_process, hps.ic_dim]) |
|
post_g0_mean = np.zeros([E_to_process, hps.ic_dim]) |
|
post_g0_logvar = np.zeros([E_to_process, hps.ic_dim]) |
|
|
|
if hps.co_dim > 0: |
|
controller_outputs = np.zeros([E_to_process, T, hps.co_dim]) |
|
gen_ics = np.zeros([E_to_process, hps.gen_dim]) |
|
gen_states = np.zeros([E_to_process, T, hps.gen_dim]) |
|
factors = np.zeros([E_to_process, T, hps.factors_dim]) |
|
|
|
if hps.output_dist == 'poisson': |
|
out_dist_params = np.zeros([E_to_process, T, D]) |
|
elif hps.output_dist == 'gaussian': |
|
out_dist_params = np.zeros([E_to_process, T, D+D]) |
|
else: |
|
assert False, "NIY" |
|
|
|
costs = np.zeros(E_to_process) |
|
nll_bound_vaes = np.zeros(E_to_process) |
|
nll_bound_iwaes = np.zeros(E_to_process) |
|
train_steps = np.zeros(E_to_process) |
|
for es_idx in range(E_to_process): |
|
print("Running %d of %d." % (es_idx+1, E_to_process)) |
|
example_idxs = es_idx * np.ones(batch_size, dtype=np.int32) |
|
data_bxtxd, ext_input_bxtxi = self.get_batch(data_extxd, |
|
ext_input_extxi, |
|
batch_size=batch_size, |
|
example_idxs=example_idxs) |
|
model_values = self.eval_model_runs_batch(data_name, data_bxtxd, |
|
ext_input_bxtxi, |
|
do_eval_cost=True, |
|
do_average_batch=True) |
|
|
|
if self.hps.ic_dim > 0: |
|
prior_g0_mean[es_idx,:] = model_values['prior_g0_mean'] |
|
prior_g0_logvar[es_idx,:] = model_values['prior_g0_logvar'] |
|
post_g0_mean[es_idx,:] = model_values['post_g0_mean'] |
|
post_g0_logvar[es_idx,:] = model_values['post_g0_logvar'] |
|
gen_ics[es_idx,:] = model_values['gen_ics'] |
|
|
|
if self.hps.co_dim > 0: |
|
controller_outputs[es_idx,:,:] = model_values['controller_outputs'] |
|
gen_states[es_idx,:,:] = model_values['gen_states'] |
|
factors[es_idx,:,:] = model_values['factors'] |
|
out_dist_params[es_idx,:,:] = model_values['output_dist_params'] |
|
costs[es_idx] = model_values['costs'] |
|
nll_bound_vaes[es_idx] = model_values['nll_bound_vaes'] |
|
nll_bound_iwaes[es_idx] = model_values['nll_bound_iwaes'] |
|
train_steps[es_idx] = model_values['train_steps'] |
|
print('bound nll(vae): %.3f, bound nll(iwae): %.3f' \ |
|
% (nll_bound_vaes[es_idx], nll_bound_iwaes[es_idx])) |
|
|
|
model_runs = {} |
|
if self.hps.ic_dim > 0: |
|
model_runs['prior_g0_mean'] = prior_g0_mean |
|
model_runs['prior_g0_logvar'] = prior_g0_logvar |
|
model_runs['post_g0_mean'] = post_g0_mean |
|
model_runs['post_g0_logvar'] = post_g0_logvar |
|
model_runs['gen_ics'] = gen_ics |
|
|
|
if self.hps.co_dim > 0: |
|
model_runs['controller_outputs'] = controller_outputs |
|
model_runs['gen_states'] = gen_states |
|
model_runs['factors'] = factors |
|
model_runs['output_dist_params'] = out_dist_params |
|
model_runs['costs'] = costs |
|
model_runs['nll_bound_vaes'] = nll_bound_vaes |
|
model_runs['nll_bound_iwaes'] = nll_bound_iwaes |
|
model_runs['train_steps'] = train_steps |
|
return model_runs |
|
|
|
def eval_model_runs_push_mean(self, data_name, data_extxd, |
|
ext_input_extxi=None): |
|
"""Returns values of interest for the model by pushing the means through |
|
|
|
The mean values for both initial conditions and the control inputs are |
|
pushed through the model instead of sampling (as is done in |
|
eval_model_runs_avg_epoch). |
|
This is a quick and approximate version of estimating these values instead |
|
of sampling from the posterior many times and then averaging those values of |
|
interest. |
|
|
|
Internally, a total of batch_size trials are run through the model at once. |
|
|
|
Args: |
|
data_name: The name of the data dict, to select which in/out matrices |
|
to use. |
|
data_extxd: Numpy array training data with shape: |
|
# examples x # time steps x # dimensions |
|
ext_input_extxi (optional): Numpy array training external input with |
|
shape: # examples x # time steps x # external input dims |
|
|
|
Returns: |
|
A dictionary with the estimated outputs of the model decoder, namely: |
|
        prior g0 mean, prior g0 variance, approx. posterior mean, approx.
        posterior variance, the generator initial conditions, the control inputs (if
|
enabled), the state of the generator, the factors, and the output |
|
distribution parameters, e.g. (rates or mean and variances). |
|
""" |
|
hps = self.hps |
|
batch_size = hps.batch_size |
|
E, T, D = data_extxd.shape |
|
E_to_process = hps.ps_nexamples_to_process |
|
if E_to_process > E: |
|
print("Setting number of posterior samples to process to : ", E) |
|
E_to_process = E |
|
|
|
if hps.ic_dim > 0: |
|
prior_g0_mean = np.zeros([E_to_process, hps.ic_dim]) |
|
prior_g0_logvar = np.zeros([E_to_process, hps.ic_dim]) |
|
post_g0_mean = np.zeros([E_to_process, hps.ic_dim]) |
|
post_g0_logvar = np.zeros([E_to_process, hps.ic_dim]) |
|
|
|
if hps.co_dim > 0: |
|
controller_outputs = np.zeros([E_to_process, T, hps.co_dim]) |
|
gen_ics = np.zeros([E_to_process, hps.gen_dim]) |
|
gen_states = np.zeros([E_to_process, T, hps.gen_dim]) |
|
factors = np.zeros([E_to_process, T, hps.factors_dim]) |
|
|
|
if hps.output_dist == 'poisson': |
|
out_dist_params = np.zeros([E_to_process, T, D]) |
|
elif hps.output_dist == 'gaussian': |
|
out_dist_params = np.zeros([E_to_process, T, D+D]) |
|
else: |
|
assert False, "NIY" |
|
|
|
costs = np.zeros(E_to_process) |
|
nll_bound_vaes = np.zeros(E_to_process) |
|
nll_bound_iwaes = np.zeros(E_to_process) |
|
train_steps = np.zeros(E_to_process) |
|
|
|
|
|
|
|
|
|
    def trial_batches(N, per):
      """Yield contiguous index arrays of at most `per` trials, covering N."""
      for i in range(0, N, per):
        yield np.arange(i, min(i+per, N), dtype=np.int32)
|
|
|
for batch_idx, es_idx in enumerate(trial_batches(E_to_process, |
|
hps.batch_size)): |
|
print("Running trial batch %d with %d trials" % (batch_idx+1, |
|
len(es_idx))) |
|
data_bxtxd, ext_input_bxtxi = self.get_batch(data_extxd, |
|
ext_input_extxi, |
|
batch_size=batch_size, |
|
example_idxs=es_idx) |
|
model_values = self.eval_model_runs_batch(data_name, data_bxtxd, |
|
ext_input_bxtxi, |
|
do_eval_cost=True, |
|
do_average_batch=False) |
|
|
|
if self.hps.ic_dim > 0: |
|
prior_g0_mean[es_idx,:] = model_values['prior_g0_mean'] |
|
prior_g0_logvar[es_idx,:] = model_values['prior_g0_logvar'] |
|
post_g0_mean[es_idx,:] = model_values['post_g0_mean'] |
|
post_g0_logvar[es_idx,:] = model_values['post_g0_logvar'] |
|
gen_ics[es_idx,:] = model_values['gen_ics'] |
|
|
|
if self.hps.co_dim > 0: |
|
controller_outputs[es_idx,:,:] = model_values['controller_outputs'] |
|
gen_states[es_idx,:,:] = model_values['gen_states'] |
|
factors[es_idx,:,:] = model_values['factors'] |
|
out_dist_params[es_idx,:,:] = model_values['output_dist_params'] |
|
|
|
|
|
|
|
|
|
      # These come back as batch-level scalars from eval_model_runs_batch, so
      # every trial in this batch is assigned the same value.
      costs[es_idx] = model_values['costs']
      nll_bound_vaes[es_idx] = model_values['nll_bound_vaes']
      nll_bound_iwaes[es_idx] = model_values['nll_bound_iwaes']
      train_steps[es_idx] = model_values['train_steps']
|
|
|
model_runs = {} |
|
if self.hps.ic_dim > 0: |
|
model_runs['prior_g0_mean'] = prior_g0_mean |
|
model_runs['prior_g0_logvar'] = prior_g0_logvar |
|
model_runs['post_g0_mean'] = post_g0_mean |
|
model_runs['post_g0_logvar'] = post_g0_logvar |
|
model_runs['gen_ics'] = gen_ics |
|
|
|
if self.hps.co_dim > 0: |
|
model_runs['controller_outputs'] = controller_outputs |
|
model_runs['gen_states'] = gen_states |
|
model_runs['factors'] = factors |
|
model_runs['output_dist_params'] = out_dist_params |
|
|
|
|
|
|
|
model_runs['costs'] = costs |
|
model_runs['nll_bound_vaes'] = nll_bound_vaes |
|
model_runs['nll_bound_iwaes'] = nll_bound_iwaes |
|
model_runs['train_steps'] = train_steps |
|
return model_runs |
|
|
|
def write_model_runs(self, datasets, output_fname=None, push_mean=False): |
|
"""Run the model on the data in data_dict, and save the computed values. |
|
|
|
    LFADS generates a number of outputs for each example, and these are all
|
saved. They are: |
|
The mean and variance of the prior of g0. |
|
The mean and variance of approximate posterior of g0. |
|
The control inputs (if enabled) |
|
The initial conditions, g0, for all examples. |
|
The generator states for all time. |
|
The factors for all time. |
|
The output distribution parameters (e.g. rates) for all time. |
|
|
|
Args: |
|
datasets: a dictionary of named data_dictionaries, see top of lfads.py |
|
output_fname: a file name stem for the output files. |
|
      push_mean: If False (default), generates batch_size samples for each trial
        and averages the results.  If True, runs each trial once without noise,
        pushing the posterior mean initial conditions and control inputs through
        the trained model.  False is used for posterior_sample_and_average, True
        is used for posterior_push_mean.
|
""" |
|
hps = self.hps |
|
kind = hps.kind |
|
|
|
for data_name, data_dict in datasets.items(): |
|
data_tuple = [('train', data_dict['train_data'], |
|
data_dict['train_ext_input']), |
|
('valid', data_dict['valid_data'], |
|
data_dict['valid_ext_input'])] |
|
for data_kind, data_extxd, ext_input_extxi in data_tuple: |
|
if not output_fname: |
|
fname = "model_runs_" + data_name + '_' + data_kind + '_' + kind |
|
else: |
|
fname = output_fname + data_name + '_' + data_kind + '_' + kind |
|
|
|
print("Writing data for %s data and kind %s." % (data_name, data_kind)) |
|
if push_mean: |
|
model_runs = self.eval_model_runs_push_mean(data_name, data_extxd, |
|
ext_input_extxi) |
|
else: |
|
model_runs = self.eval_model_runs_avg_epoch(data_name, data_extxd, |
|
ext_input_extxi) |
|
full_fname = os.path.join(hps.lfads_save_dir, fname) |
|
write_data(full_fname, model_runs, compression='gzip') |
|
print("Done.") |
|
|
|
def write_model_samples(self, dataset_name, output_fname=None): |
|
"""Use the prior distribution to generate batch_size number of samples |
|
from the model. |
|
|
|
LFADS generates a number of outputs for each sample, and these are all |
|
saved. They are: |
|
The mean and variance of the prior of g0. |
|
The control inputs (if enabled) |
|
The initial conditions, g0, for all examples. |
|
The generator states for all time. |
|
The factors for all time. |
|
The output distribution parameters (e.g. rates) for all time. |
|
|
|
Args: |
|
dataset_name: The name of the dataset to grab the factors -> rates |
|
alignment matrices from. |
|
output_fname: The name of the file in which to save the generated |
|
samples. |
|
""" |
|
hps = self.hps |
|
batch_size = hps.batch_size |
|
|
|
print("Generating %d samples" % (batch_size)) |
|
tf_vals = [self.factors, self.gen_states, self.gen_ics, |
|
self.cost, self.output_dist_params] |
|
if hps.ic_dim > 0: |
|
tf_vals += [self.prior_zs_g0.mean, self.prior_zs_g0.logvar] |
|
if hps.co_dim > 0: |
|
tf_vals += [self.prior_zs_ar_con.samples_t] |
|
tf_vals_flat, fidxs = flatten(tf_vals) |
|
|
|
session = tf.get_default_session() |
|
feed_dict = {} |
|
feed_dict[self.dataName] = dataset_name |
|
feed_dict[self.keep_prob] = 1.0 |
|
|
|
np_vals_flat = session.run(tf_vals_flat, feed_dict=feed_dict) |
|
|
|
ff = 0 |
|
factors = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1 |
|
gen_states = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1 |
|
gen_ics = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1 |
|
costs = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1 |
|
output_dist_params = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1 |
|
if hps.ic_dim > 0: |
|
prior_g0_mean = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1 |
|
prior_g0_logvar = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1 |
|
if hps.co_dim > 0: |
|
prior_zs_ar_con = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1 |
|
|
|
|
|
gen_ics = gen_ics[0] |
|
costs = costs[0] |
|
|
|
|
|
gen_states = list_t_bxn_to_tensor_bxtxn(gen_states) |
|
factors = list_t_bxn_to_tensor_bxtxn(factors) |
|
output_dist_params = list_t_bxn_to_tensor_bxtxn(output_dist_params) |
|
if hps.ic_dim > 0: |
|
prior_g0_mean = prior_g0_mean[0] |
|
prior_g0_logvar = prior_g0_logvar[0] |
|
if hps.co_dim > 0: |
|
prior_zs_ar_con = list_t_bxn_to_tensor_bxtxn(prior_zs_ar_con) |
|
|
|
model_vals = {} |
|
model_vals['gen_ics'] = gen_ics |
|
model_vals['gen_states'] = gen_states |
|
model_vals['factors'] = factors |
|
model_vals['output_dist_params'] = output_dist_params |
|
model_vals['costs'] = costs.reshape(1) |
|
if hps.ic_dim > 0: |
|
model_vals['prior_g0_mean'] = prior_g0_mean |
|
model_vals['prior_g0_logvar'] = prior_g0_logvar |
|
if hps.co_dim > 0: |
|
model_vals['prior_zs_ar_con'] = prior_zs_ar_con |
|
|
|
full_fname = os.path.join(hps.lfads_save_dir, output_fname) |
|
write_data(full_fname, model_vals, compression='gzip') |
|
print("Done.") |
|
|
|
@staticmethod |
|
def eval_model_parameters(use_nested=True, include_strs=None): |
|
"""Evaluate and return all of the TF variables in the model. |
|
|
|
Args: |
|
      use_nested (optional): For returning values, use a nested dictionary, based
|
on variable scoping, or return all variables in a flat dictionary. |
|
include_strs (optional): A list of strings to use as a filter, to reduce the |
|
number of variables returned. A variable name must contain at least one |
|
string in include_strs as a sub-string in order to be returned. |
|
|
|
Returns: |
|
The parameters of the model. This can be in a flat |
|
dictionary, or a nested dictionary, where the nesting is by variable |
|
scope. |
|
""" |
|
all_tf_vars = tf.global_variables() |
|
session = tf.get_default_session() |
|
all_tf_vars_eval = session.run(all_tf_vars) |
|
vars_dict = {} |
|
strs = ["LFADS"] |
|
if include_strs: |
|
strs += include_strs |
|
|
|
for i, (var, var_eval) in enumerate(zip(all_tf_vars, all_tf_vars_eval)): |
|
      # Keep a variable if its name contains any of the filter strings.
      if any(s in var.name for s in strs):
|
if not isinstance(var_eval, np.ndarray): |
|
print(var.name, """ is not numpy array, saving as numpy array |
|
with value: """, var_eval, type(var_eval)) |
|
e = np.array(var_eval) |
|
print(e, type(e)) |
|
else: |
|
e = var_eval |
|
vars_dict[var.name] = e |
|
|
|
if not use_nested: |
|
return vars_dict |
|
|
|
var_names = vars_dict.keys() |
|
nested_vars_dict = {} |
|
current_dict = nested_vars_dict |
|
for v, var_name in enumerate(var_names): |
|
var_split_name_list = var_name.split('/') |
|
split_name_list_len = len(var_split_name_list) |
|
current_dict = nested_vars_dict |
|
for p, part in enumerate(var_split_name_list): |
|
if p < split_name_list_len - 1: |
|
if part in current_dict: |
|
current_dict = current_dict[part] |
|
else: |
|
current_dict[part] = {} |
|
current_dict = current_dict[part] |
|
else: |
|
current_dict[part] = vars_dict[var_name] |
|
|
|
return nested_vars_dict |
|
|
|
@staticmethod |
|
def spikify_rates(rates_bxtxd): |
|
"""Randomly spikify underlying rates according a Poisson distribution |
|
|
|
Args: |
|
      rates_bxtxd: a numpy tensor of underlying rates with shape:
        # batch x # time steps x # dimensions
|
Returns: |
|
A numpy array with the same shape as rates_bxtxd, but with the event |
|
counts. |
|
""" |
|
|
|
    B, T, N = rates_bxtxd.shape
    assert B > 0 and N > 0, "rates_bxtxd must have nonzero batch and data dims"
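    # The explicit loops below keep the sampling obvious; np.random.poisson
    # also accepts an array of rates, so an equivalent vectorized draw would be
    #   spikes_bxtxd = np.random.poisson(rates_bxtxd).astype(np.int32)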
|
|
|
|
|
spikes_bxtxd = np.zeros([B,T,N], dtype=np.int32) |
|
for b in range(B): |
|
for t in range(T): |
|
for n in range(N): |
|
rate = rates_bxtxd[b,t,n] |
|
count = np.random.poisson(rate) |
|
spikes_bxtxd[b,t,n] = count |
|
|
|
return spikes_bxtxd |
|
|