|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import print_function |
|
|
|
import h5py |
|
import numpy as np |
|
import os |
|
from six.moves import xrange |
|
import tensorflow as tf |
|
|
|
from utils import write_datasets |
|
from synthetic_data_utils import normalize_rates |
|
from synthetic_data_utils import get_train_n_valid_inds, nparray_and_transpose |
|
from synthetic_data_utils import spikify_data, split_list_by_inds |
|
|
|
DATA_DIR = "rnn_synth_data_v1.0"

# Command-line flags controlling the synthetic data generation.
flags = tf.app.flags
flags.DEFINE_string("save_dir", "/tmp/" + DATA_DIR + "/",
                    "Directory for saving data.")
flags.DEFINE_string("datafile_name", "itb_rnn",
                    "Name of data file for input case.")
flags.DEFINE_integer("synth_data_seed", 5, "Random seed for RNN generation.")
flags.DEFINE_float("T", 1.0, "Time in seconds to generate.")
flags.DEFINE_integer("C", 800, "Number of conditions")
flags.DEFINE_integer("N", 50, "Number of units for the RNN")
# Stored as a fraction in [0, 1], not as a 0-100 percentage.
flags.DEFINE_float("train_percentage", 4.0/5.0,
                   "Fraction (0-1) of trials used for training; the rest "
                   "are used for validation.")
flags.DEFINE_integer("nreplications", 5,
                     "Number of spikifications of the same underlying rates.")
# NOTE(review): FLAGS.tau does not appear to be read by this script's
# generation code; confirm whether it is still needed.
flags.DEFINE_float("tau", 0.025, "Time constant of RNN")
flags.DEFINE_float("dt", 0.010, "Time bin")
flags.DEFINE_float("max_firing_rate", 30.0,
                   "Map 1.0 of RNN to this many spikes per second")
flags.DEFINE_float("u_std", 0.25,
                   "Std dev of input to integration to bound model")
# The value is handed directly to `saver.restore`, so it must be a
# checkpoint path (prefix), not a directory.
flags.DEFINE_string("checkpoint_path", "SAMPLE_CHECKPOINT",
                    "Path to a checkpoint of a model trained on the "
                    "integration-to-bound task. The default placeholder "
                    "'SAMPLE_CHECKPOINT' selects the checkpoint provided "
                    "with the code (trained_itb/model-65000, relative to "
                    "this script's directory). To restore your own model, "
                    "point this at its checkpoint path.")
FLAGS = flags.FLAGS
|
|
|
|
|
class IntegrationToBoundModel:
    """Single-layer tanh RNN with a scalar input and a scalar linear readout.

    Only the one-step forward computation is defined here. In this script the
    initial variable values are replaced by restoring a checkpoint.

    NOTE: the variables are created without explicit names, so TensorFlow
    assigns default names in creation order. The checkpoint is restored by
    those names, so do not reorder, rename, or add variables here.
    """

    def __init__(self, N):
        """Create the model variables.

        Args:
          N: number of hidden units.
        """
        # Random initialisation scaled by 0.8 / sqrt(N).
        scale = 0.8 / float(N**0.5)
        self.N = N
        # Recurrent weights, N x N.
        self.Wh_nxn = tf.Variable(tf.random_normal([N, N], stddev=scale))
        # Hidden bias, 1 x N (broadcast over the batch).
        self.b_1xn = tf.Variable(tf.zeros([1, N]))
        # Weights applied to the scalar input, 1 x N.
        self.Bu_1xn = tf.Variable(tf.zeros([1, N]))
        # Readout weights (N x 1) and scalar readout bias.
        self.Wro_nxo = tf.Variable(tf.random_normal([N, 1], stddev=scale))
        self.bro_o = tf.Variable(tf.zeros([1]))

    def call(self, h_tm1_bxn, u_bx1):
        """Advance the RNN by one time step.

        Args:
          h_tm1_bxn: previous hidden state, presumably (batch, N).
          u_bx1: input for this step, presumably (batch, 1).

        Returns:
          (z_t, h_t_bxn): the linear readout for this step and the new
          hidden state.
        """
        # h_t = tanh(h_{t-1} @ Wh + b + u_t * Bu)
        act_t_bxn = tf.matmul(h_tm1_bxn, self.Wh_nxn) + self.b_1xn + u_bx1 * self.Bu_1xn
        h_t_bxn = tf.nn.tanh(act_t_bxn)
        # z_t = h_t @ Wro + bro
        z_t = tf.nn.xw_plus_b(h_t_bxn, self.Wro_nxo, self.bro_o)
        return z_t, h_t_bxn
|
|
|
def get_data_batch(batch_size, T, rng, u_std):
    """Draw Gaussian input sequences and their clipped running-sum targets.

    Args:
      batch_size: number of sequences to draw.
      T: number of time steps per sequence (an integer count, not seconds).
      rng: random source exposing ``randn`` (e.g. ``np.random.RandomState``).
      u_std: standard deviation of the inputs.

    Returns:
      (u_bxt, labels_bxt), each of shape (batch_size, T): the inputs, and the
      cumulative sum of the inputs over time clipped to [-1, 1].
    """
    inputs = rng.randn(batch_size, T) * u_std
    # Accumulate along time first and clip afterwards, so the underlying
    # running total itself is unbounded; only the reported target is clipped.
    totals = np.cumsum(inputs, axis=1)
    return inputs, np.clip(totals, -1, 1)
|
|
|
|
|
# Two independent random streams: `rng` is used for the random projection
# and the spikification, `u_rng` for the per-condition inputs.
rng = np.random.RandomState(seed=FLAGS.synth_data_seed)
u_rng = np.random.RandomState(seed=FLAGS.synth_data_seed+1)
T = FLAGS.T  # trial duration in seconds
C = FLAGS.C
N = FLAGS.N
nreplications = FLAGS.nreplications
E = nreplications * C  # total trials: conditions x replications
train_percentage = FLAGS.train_percentage
# Round instead of truncating: T / dt is often not exactly representable
# (e.g. 0.3 / 0.1 == 2.9999999999999996), and int() alone would drop a bin.
ntimesteps = int(round(T / FLAGS.dt))
if ntimesteps < 1:
    raise ValueError("T (%g s) must span at least one time bin of dt=%g s."
                     % (T, FLAGS.dt))
batch_size = 1  # one condition is simulated per sess.run call
|
|
|
# Build the unrolled RNN graph: one input placeholder per time step, fed
# through the model step by step from a zero initial state.
model = IntegrationToBoundModel(N)
inputs_ph_t = [tf.placeholder(tf.float32,
                              shape=[None, 1]) for _ in range(ntimesteps)]
state = tf.zeros([batch_size, N])
# Created after the model so the Saver covers the model's variables; it is
# used to restore the trained weights inside the session.
saver = tf.train.Saver()

# Random N x N projection (scaled by 1/sqrt(N)) applied to the RNN states to
# produce the underlying rates. It is drawn from `rng`, so reordering draws
# from `rng` would change all later random values (including spikes).
P_nxn = rng.randn(N,N) / np.sqrt(N)

# Unroll the RNN over all time steps, collecting per-step readouts and
# hidden states.
outputs_t = []
states_t = []

for inp in inputs_ph_t:
    output, state = model.call(state, inp)
    outputs_t.append(output)
    states_t.append(state)
|
|
|
with tf.Session() as sess:
    # Pick the checkpoint: the sample shipped next to this script, or the
    # user-supplied checkpoint path.
    if FLAGS.checkpoint_path == "SAMPLE_CHECKPOINT":
        dir_path = os.path.dirname(os.path.realpath(__file__))
        model_checkpoint_path = os.path.join(dir_path, "trained_itb/model-65000")
    else:
        model_checkpoint_path = FLAGS.checkpoint_path
    # Raise explicitly rather than `assert False` inside a bare except.
    # Asserts are stripped under `python -O`, which would swallow the restore
    # failure. A bare except would also trap KeyboardInterrupt.
    try:
        saver.restore(sess, model_checkpoint_path)
    except Exception:
        raise IOError("No checkpoints to restore from, is the path %s correct?"
                      % model_checkpoint_path)
    print('Model restored from', model_checkpoint_path)

    # Simulate one condition (batch of 1) at a time. Record the projected
    # states, the inputs, and the readouts once per replication.
    data_e = []
    u_e = []
    outs_e = []
    for c in range(C):
        # The clipped running-sum targets (outs_1xt) are not stored; the
        # network's own readouts are recorded instead.
        u_1xt, outs_1xt = get_data_batch(batch_size, ntimesteps, u_rng, FLAGS.u_std)

        # Column t of the inputs feeds the placeholder for step t.
        feed_dict = {ph: np.reshape(u_1xt[:, t], (batch_size, -1))
                     for t, ph in enumerate(inputs_ph_t)}

        states_t_bxn, outputs_t_bxn = sess.run([states_t, outputs_t],
                                               feed_dict=feed_dict)
        # Stacked states are (ntimesteps, batch_size, N). Drop only the batch
        # axis (batch_size is 1) so an N or ntimesteps of 1 is not collapsed
        # too, then transpose to (N, ntimesteps).
        states_nxt = np.transpose(np.squeeze(np.asarray(states_t_bxn), axis=1))
        # Readouts stack to (ntimesteps, batch_size, 1) -> (ntimesteps,).
        outputs_t_bxn = np.squeeze(np.asarray(outputs_t_bxn), axis=(1, 2))
        r_sxt = np.dot(P_nxn, states_nxt)

        for s in xrange(nreplications):
            # Store an independent copy per replication. Otherwise every
            # replication of a condition is the same array object, and any
            # in-place processing downstream (NOTE(review): e.g. in
            # normalize_rates) would hit all replications at once.
            data_e.append(r_sxt.copy())
            u_e.append(u_1xt)
            outs_e.append(outputs_t_bxn)

    truth_data_e = normalize_rates(data_e, E, N)
|
|
|
# Sample spikes from the normalized rates, then split every per-trial list
# into training and validation sets using one shared set of indices.
spiking_data_e = spikify_data(truth_data_e, rng, dt=FLAGS.dt,
                              max_firing_rate=FLAGS.max_firing_rate)
train_inds, valid_inds = get_train_n_valid_inds(E, train_percentage,
                                                nreplications)

data_train_truth, data_valid_truth = split_list_by_inds(
    truth_data_e, train_inds, valid_inds)
data_train_spiking, data_valid_spiking = split_list_by_inds(
    spiking_data_e, train_inds, valid_inds)

# Convert the four rate/spike splits to arrays, in the same order as before.
(data_train_truth, data_valid_truth,
 data_train_spiking, data_valid_spiking) = [
    nparray_and_transpose(part) for part in (data_train_truth,
                                             data_valid_truth,
                                             data_train_spiking,
                                             data_valid_spiking)]

# Inputs go through the same conversion as the neural data.
train_inputs_u, valid_inputs_u = [
    nparray_and_transpose(part)
    for part in split_list_by_inds(u_e, train_inds, valid_inds)]

# Readouts are stacked into plain arrays without the transpose helper.
train_outputs_u, valid_outputs_u = [
    np.array(part)
    for part in split_list_by_inds(outs_e, train_inds, valid_inds)]
|
|
|
|
|
# Everything written out for this dataset: ground-truth rates, spikes,
# inputs, readouts, and the generation parameters.
data = dict(
    train_truth=data_train_truth,
    valid_truth=data_valid_truth,
    train_data=data_train_spiking,
    valid_data=data_valid_spiking,
    train_percentage=train_percentage,
    nreplications=nreplications,
    dt=FLAGS.dt,
    u_std=FLAGS.u_std,
    max_firing_rate=FLAGS.max_firing_rate,
    train_inputs_u=train_inputs_u,
    valid_inputs_u=valid_inputs_u,
    train_outputs_u=train_outputs_u,
    valid_outputs_u=valid_outputs_u,
    # Expected spikes per bin at a normalized rate of 1.0 (max rate * dt).
    conversion_factor=FLAGS.max_firing_rate/(1.0/FLAGS.dt),
)
|
|
|
|
|
# A single dataset, keyed by a name that records the network size.
dataset_name = 'dataset_N' + str(N)
datasets = {dataset_name: data}

write_datasets(FLAGS.save_dir, FLAGS.datafile_name, datasets)
saved_path = os.path.join(FLAGS.save_dir,
                          FLAGS.datafile_name + '_' + dataset_name)
print('Saved to ', saved_path)
|
|