# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Main script for running fivo"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import defaultdict
import numpy as np
import tensorflow as tf
import bounds
import data
import models
import summary_utils as summ
tf.logging.set_verbosity(tf.logging.INFO)
tf.app.flags.DEFINE_integer("random_seed", None,
"A random seed for the data generating process. Same seed "
"-> same data generating process and initialization.")
tf.app.flags.DEFINE_enum("bound", "fivo", ["iwae", "fivo", "fivo-aux", "fivo-aux-td"],
"The bound to optimize.")
tf.app.flags.DEFINE_enum("model", "forward", ["forward", "long_chain"],
"The model to use.")
tf.app.flags.DEFINE_enum("q_type", "normal",
["normal", "simple_mean", "prev_state", "observation"],
"The parameterization to use for q")
tf.app.flags.DEFINE_enum("p_type", "unimodal", ["unimodal", "bimodal", "nonlinear"],
"The type of prior.")
tf.app.flags.DEFINE_boolean("train_p", True,
"If false, do not train the model p.")
tf.app.flags.DEFINE_integer("state_size", 1,
"The dimensionality of the state space.")
tf.app.flags.DEFINE_float("variance", 1.0,
"The variance of the data generating process.")
tf.app.flags.DEFINE_boolean("use_bs", True,
"If False, initialize all bs to 0.")
tf.app.flags.DEFINE_float("bimodal_prior_weight", 0.5,
"The weight assigned to the positive mode of the prior in "
"both the data generating process and p.")
tf.app.flags.DEFINE_float("bimodal_prior_mean", None,
"If supplied, sets the mean of the 2 modes of the prior to "
"be 1 and -1 times the supplied value. This is for both the "
"data generating process and p.")
tf.app.flags.DEFINE_float("fixed_observation", None,
"If supplied, fix the observation to a constant value in the"
" data generating process only.")
tf.app.flags.DEFINE_float("r_sigma_init", 1.,
"Value to initialize variance of r to.")
tf.app.flags.DEFINE_enum("observation_type",
models.STANDARD_OBSERVATION, models.OBSERVATION_TYPES,
"The type of observation for the long chain model.")
tf.app.flags.DEFINE_enum("transition_type",
models.STANDARD_TRANSITION, models.TRANSITION_TYPES,
"The type of transition for the long chain model.")
tf.app.flags.DEFINE_float("observation_variance", None,
"The variance of the observation. Defaults to 'variance'")
tf.app.flags.DEFINE_integer("num_timesteps", 5,
"Number of timesteps in the sequence.")
tf.app.flags.DEFINE_integer("num_observations", 1,
"The number of observations.")
tf.app.flags.DEFINE_integer("steps_per_observation", 5,
"The number of timesteps between each observation.")
tf.app.flags.DEFINE_integer("batch_size", 4,
"The number of examples per batch.")
tf.app.flags.DEFINE_integer("num_samples", 4,
"The number particles to use.")
tf.app.flags.DEFINE_integer("num_eval_samples", 512,
"The batch size and # of particles to use for eval.")
tf.app.flags.DEFINE_string("resampling", "always",
"How to resample. Accepts 'always','never', or a "
"comma-separated list of booleans like 'true,true,false'.")
tf.app.flags.DEFINE_enum("resampling_method", "multinomial", ["multinomial",
"stratified",
"systematic",
"relaxed-logblend",
"relaxed-stateblend",
"relaxed-linearblend",
"relaxed-stateblend-st",],
"Type of resampling method to use.")
tf.app.flags.DEFINE_boolean("use_resampling_grads", True,
"Whether or not to use resampling grads to optimize FIVO."
"Disabled automatically if resampling_method=relaxed.")
tf.app.flags.DEFINE_boolean("disable_r", False,
"If false, r is not used for fivo-aux and is set to zeros.")
tf.app.flags.DEFINE_float("learning_rate", 1e-4,
"The learning rate to use for ADAM or SGD.")
tf.app.flags.DEFINE_integer("decay_steps", 25000,
"The number of steps before the learning rate is halved.")
tf.app.flags.DEFINE_integer("max_steps", int(1e6),
"The number of steps to run training for.")
tf.app.flags.DEFINE_string("logdir", "/tmp/fivo-aux",
"Directory for summaries and checkpoints.")
tf.app.flags.DEFINE_integer("summarize_every", int(1e3),
"The number of steps between each evaluation.")
FLAGS = tf.app.flags.FLAGS
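# A minimal example invocation (the script filename here is assumed):
#   python fivo_main.py --bound=fivo --model=forward --num_timesteps=5 \
#       --num_samples=4 --batch_size=4 --logdir=/tmp/fivo-aux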
def combine_grad_lists(grad_lists):
  """Sums gradients for each variable across per-loss gradient lists.

  `grad_lists` is num_losses by num_variables, and each list may contain
  a different set of variables. For each variable, the gradients are
  summed across all losses; variables with no gradients get None.
  """
grads_dict = defaultdict(list)
var_dict = {}
for grad_list in grad_lists:
for grad, var in grad_list:
if grad is not None:
grads_dict[var.name].append(grad)
var_dict[var.name] = var
final_grads = []
  for var_name, var in var_dict.items():
grads = grads_dict[var_name]
if len(grads) > 0:
tf.logging.info("Var %s has combined grads from %s." %
(var_name, [g.name for g in grads]))
grad = tf.reduce_sum(grads, axis=0)
else:
tf.logging.info("Var %s has no grads" % var_name)
grad = None
final_grads.append((grad, var))
return final_grads
def make_apply_grads_op(losses, global_step, learning_rate, lr_decay_steps):
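  """Builds a train op that applies the combined gradients of all losses.

  Each element of `losses` must be a bounds.Loss giving the loss name,
  the loss tensor, and the variable collection to differentiate against.
  A single Adam optimizer with an exponentially decaying learning rate
  (halved every `lr_decay_steps` steps) applies the summed gradients, and
  the per-loss gradient EMA summaries are updated afterwards.
  """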
for l in losses:
assert isinstance(l, bounds.Loss)
lr = tf.train.exponential_decay(
learning_rate, global_step, lr_decay_steps, 0.5, staircase=False)
tf.summary.scalar("learning_rate", lr)
opt = tf.train.AdamOptimizer(lr)
ema_ops = []
grads = []
for loss_name, loss, loss_var_collection in losses:
tf.logging.info("Computing grads of %s w.r.t. vars in collection %s" %
(loss_name, loss_var_collection))
g = opt.compute_gradients(loss,
var_list=tf.get_collection(loss_var_collection))
ema_ops.append(summ.summarize_grads(g, loss_name))
grads.append(g)
all_grads = combine_grad_lists(grads)
apply_grads_op = opt.apply_gradients(all_grads, global_step=global_step)
# Update the emas after applying the grads.
with tf.control_dependencies([apply_grads_op]):
train_op = tf.group(*ema_ops)
return train_op
def add_check_numerics_ops():
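  """Groups check_numerics ops for float tensors in the default graph.

  A variant of tf.add_check_numerics_ops that skips ops (the `bad` list
  below) whose outputs may legitimately be non-finite, chaining the
  checks via control dependencies. Raises ValueError if a checked op is
  inside a control flow context.
  """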
check_op = []
for op in tf.get_default_graph().get_operations():
bad = ["logits/Log", "sample/Reshape", "log_prob/mul",
"log_prob/SparseSoftmaxCrossEntropyWithLogits/Reshape",
"entropy/Reshape", "entropy/LogSoftmax", "Categorical", "Mean"]
if all([x not in op.name for x in bad]):
for output in op.outputs:
if output.dtype in [tf.float16, tf.float32, tf.float64]:
if op._get_control_flow_context() is not None: # pylint: disable=protected-access
raise ValueError("`tf.add_check_numerics_ops() is not compatible "
"with TensorFlow control flow operations such as "
"`tf.cond()` or `tf.while_loop()`.")
message = op.name + ":" + str(output.value_index)
with tf.control_dependencies(check_op):
check_op = [tf.check_numerics(output, message=message)]
return tf.group(*check_op)
def create_long_chain_graph(bound, state_size, num_obs, steps_per_obs,
batch_size, num_samples, num_eval_samples,
resampling_schedule, use_resampling_grads,
learning_rate, lr_decay_steps, dtype="float64"):
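  """Builds the train and eval graph for the long chain model.

  Creates train and eval datasets, a LongChainModel, and the requested
  bound ("iwae", "fivo", or "fivo-aux"), then constructs the train op.

  Returns:
    A tuple of (global_step, train_op, eval_log_p_hat, eval_likelihood).
    eval_likelihood is always zero because the true likelihood is
    intractable for these models.
  """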
num_timesteps = num_obs * steps_per_obs + 1
# Make the dataset.
dataset = data.make_long_chain_dataset(
state_size=state_size,
num_obs=num_obs,
steps_per_obs=steps_per_obs,
batch_size=batch_size,
num_samples=num_samples,
variance=FLAGS.variance,
observation_variance=FLAGS.observation_variance,
dtype=dtype,
observation_type=FLAGS.observation_type,
transition_type=FLAGS.transition_type,
fixed_observation=FLAGS.fixed_observation)
itr = dataset.make_one_shot_iterator()
_, observations = itr.get_next()
# Make the dataset for eval
eval_dataset = data.make_long_chain_dataset(
state_size=state_size,
num_obs=num_obs,
steps_per_obs=steps_per_obs,
batch_size=batch_size,
num_samples=num_eval_samples,
variance=FLAGS.variance,
observation_variance=FLAGS.observation_variance,
dtype=dtype,
observation_type=FLAGS.observation_type,
transition_type=FLAGS.transition_type,
fixed_observation=FLAGS.fixed_observation)
eval_itr = eval_dataset.make_one_shot_iterator()
_, eval_observations = eval_itr.get_next()
# Make the model.
model = models.LongChainModel.create(
state_size,
num_obs,
steps_per_obs,
observation_type=FLAGS.observation_type,
transition_type=FLAGS.transition_type,
variance=FLAGS.variance,
observation_variance=FLAGS.observation_variance,
dtype=tf.as_dtype(dtype),
disable_r=FLAGS.disable_r)
# Compute the bound and loss
if bound == "iwae":
(_, losses, ema_op, _, _) = bounds.iwae(
model,
observations,
num_timesteps,
num_samples=num_samples)
(eval_log_p_hat, _, _, _, eval_log_weights) = bounds.iwae(
model,
eval_observations,
num_timesteps,
num_samples=num_eval_samples,
summarize=False)
eval_log_p_hat = tf.reduce_mean(eval_log_p_hat)
elif bound == "fivo" or "fivo-aux":
(_, losses, ema_op, _, _) = bounds.fivo(
model,
observations,
num_timesteps,
resampling_schedule=resampling_schedule,
use_resampling_grads=use_resampling_grads,
resampling_type=FLAGS.resampling_method,
aux=("aux" in bound),
num_samples=num_samples)
(eval_log_p_hat, _, _, _, eval_log_weights) = bounds.fivo(
model,
eval_observations,
num_timesteps,
resampling_schedule=resampling_schedule,
use_resampling_grads=False,
resampling_type="multinomial",
aux=("aux" in bound),
num_samples=num_eval_samples,
summarize=False)
eval_log_p_hat = tf.reduce_mean(eval_log_p_hat)
summ.summarize_ess(eval_log_weights, only_last_timestep=True)
tf.summary.scalar("log_p_hat", eval_log_p_hat)
# Compute and apply grads.
global_step = tf.train.get_or_create_global_step()
apply_grads = make_apply_grads_op(losses,
global_step,
learning_rate,
lr_decay_steps)
# Update the emas after applying the grads.
with tf.control_dependencies([apply_grads]):
train_op = tf.group(ema_op)
# We can't calculate the likelihood for most of these models
# so we just return zeros.
eval_likelihood = tf.zeros([], dtype=dtype)
return global_step, train_op, eval_log_p_hat, eval_likelihood
def create_graph(bound, state_size, num_timesteps, batch_size,
num_samples, num_eval_samples, resampling_schedule,
use_resampling_grads, learning_rate, lr_decay_steps,
train_p, dtype='float64'):
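  """Builds the train and eval graph for the forward model.

  Creates train and eval datasets, a models.Model (or models.TDModel when
  the bound is "fivo-aux-td"), and the requested bound, then constructs
  the train op and model summaries.

  Returns:
    A tuple of (global_step, train_op, eval_log_p_hat, eval_likelihood),
    where eval_likelihood is the model's true likelihood normalized by
    the number of timesteps when available, and zeros otherwise.
  """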
if FLAGS.use_bs:
true_bs = None
else:
    true_bs = [np.zeros([state_size]).astype(dtype)
               for _ in range(num_timesteps)]
# Make the dataset.
true_bs, dataset = data.make_dataset(
bs=true_bs,
state_size=state_size,
num_timesteps=num_timesteps,
batch_size=batch_size,
num_samples=num_samples,
variance=FLAGS.variance,
prior_type=FLAGS.p_type,
bimodal_prior_weight=FLAGS.bimodal_prior_weight,
bimodal_prior_mean=FLAGS.bimodal_prior_mean,
transition_type=FLAGS.transition_type,
fixed_observation=FLAGS.fixed_observation,
dtype=dtype)
itr = dataset.make_one_shot_iterator()
_, observations = itr.get_next()
# Make the dataset for eval
_, eval_dataset = data.make_dataset(
bs=true_bs,
state_size=state_size,
num_timesteps=num_timesteps,
batch_size=num_eval_samples,
num_samples=num_eval_samples,
variance=FLAGS.variance,
prior_type=FLAGS.p_type,
bimodal_prior_weight=FLAGS.bimodal_prior_weight,
bimodal_prior_mean=FLAGS.bimodal_prior_mean,
transition_type=FLAGS.transition_type,
fixed_observation=FLAGS.fixed_observation,
dtype=dtype)
eval_itr = eval_dataset.make_one_shot_iterator()
_, eval_observations = eval_itr.get_next()
# Make the model.
if bound == "fivo-aux-td":
model = models.TDModel.create(
state_size,
num_timesteps,
variance=FLAGS.variance,
train_p=train_p,
p_type=FLAGS.p_type,
q_type=FLAGS.q_type,
mixing_coeff=FLAGS.bimodal_prior_weight,
prior_mode_mean=FLAGS.bimodal_prior_mean,
observation_variance=FLAGS.observation_variance,
transition_type=FLAGS.transition_type,
use_bs=FLAGS.use_bs,
dtype=tf.as_dtype(dtype),
random_seed=FLAGS.random_seed)
else:
model = models.Model.create(
state_size,
num_timesteps,
variance=FLAGS.variance,
train_p=train_p,
p_type=FLAGS.p_type,
q_type=FLAGS.q_type,
mixing_coeff=FLAGS.bimodal_prior_weight,
prior_mode_mean=FLAGS.bimodal_prior_mean,
observation_variance=FLAGS.observation_variance,
transition_type=FLAGS.transition_type,
use_bs=FLAGS.use_bs,
r_sigma_init=FLAGS.r_sigma_init,
dtype=tf.as_dtype(dtype),
random_seed=FLAGS.random_seed)
# Compute the bound and loss
if bound == "iwae":
(_, losses, ema_op, _, _) = bounds.iwae(
model,
observations,
num_timesteps,
num_samples=num_samples)
(eval_log_p_hat, _, _, eval_states, eval_log_weights) = bounds.iwae(
model,
eval_observations,
num_timesteps,
num_samples=num_eval_samples,
summarize=True)
eval_log_p_hat = tf.reduce_mean(eval_log_p_hat)
elif "fivo" in bound:
if bound == "fivo-aux-td":
(_, losses, ema_op, _, _) = bounds.fivo_aux_td(
model,
observations,
num_timesteps,
resampling_schedule=resampling_schedule,
num_samples=num_samples)
(eval_log_p_hat, _, _, eval_states, eval_log_weights) = bounds.fivo_aux_td(
model,
eval_observations,
num_timesteps,
resampling_schedule=resampling_schedule,
num_samples=num_eval_samples,
summarize=True)
else:
(_, losses, ema_op, _, _) = bounds.fivo(
model,
observations,
num_timesteps,
resampling_schedule=resampling_schedule,
use_resampling_grads=use_resampling_grads,
resampling_type=FLAGS.resampling_method,
aux=("aux" in bound),
num_samples=num_samples)
(eval_log_p_hat, _, _, eval_states, eval_log_weights) = bounds.fivo(
model,
eval_observations,
num_timesteps,
resampling_schedule=resampling_schedule,
use_resampling_grads=False,
resampling_type="multinomial",
aux=("aux" in bound),
num_samples=num_eval_samples,
summarize=True)
eval_log_p_hat = tf.reduce_mean(eval_log_p_hat)
summ.summarize_ess(eval_log_weights, only_last_timestep=True)
# if FLAGS.p_type == "bimodal":
# # create the observations that showcase the model.
# mode_odds_ratio = tf.convert_to_tensor([1., 3., 1./3., 512., 1./512.],
# dtype=tf.float64)
# mode_odds_ratio = tf.expand_dims(mode_odds_ratio, 1)
# k = ((num_timesteps+1) * FLAGS.variance) / (2*FLAGS.bimodal_prior_mean)
# explain_obs = tf.reduce_sum(model.p.bs) + tf.log(mode_odds_ratio) * k
# explain_obs = tf.tile(explain_obs, [num_eval_samples, 1])
# # run the model on the explainable observations
# if bound == "iwae":
# (_, _, _, explain_states, explain_log_weights) = bounds.iwae(
# model,
# explain_obs,
# num_timesteps,
# num_samples=num_eval_samples)
# elif bound == "fivo" or "fivo-aux":
# (_, _, _, explain_states, explain_log_weights) = bounds.fivo(
# model,
# explain_obs,
# num_timesteps,
# resampling_schedule=resampling_schedule,
# use_resampling_grads=False,
# resampling_type="multinomial",
# aux=("aux" in bound),
# num_samples=num_eval_samples)
# summ.summarize_particles(explain_states,
# explain_log_weights,
# explain_obs,
# model)
# Calculate the true likelihood.
  if callable(getattr(model.p, "likelihood", None)):
    eval_likelihood = (model.p.likelihood(eval_observations) /
                       FLAGS.num_timesteps)
else:
eval_likelihood = tf.zeros_like(eval_log_p_hat)
tf.summary.scalar("log_p_hat", eval_log_p_hat)
tf.summary.scalar("likelihood", eval_likelihood)
tf.summary.scalar("bound_gap", eval_likelihood - eval_log_p_hat)
summ.summarize_model(model, true_bs, eval_observations, eval_states, bound,
summarize_r=not bound == "fivo-aux-td")
# Compute and apply grads.
global_step = tf.train.get_or_create_global_step()
apply_grads = make_apply_grads_op(losses,
global_step,
learning_rate,
lr_decay_steps)
# Update the emas after applying the grads.
with tf.control_dependencies([apply_grads]):
train_op = tf.group(ema_op)
#train_op = tf.group(ema_op, add_check_numerics_ops())
return global_step, train_op, eval_log_p_hat, eval_likelihood
def parse_resampling_schedule(schedule, num_timesteps):
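  """Parses a resampling schedule string into a list of booleans.

  Accepts "always" (resample at every timestep except the last), "never",
  "every_n" for an integer n (resample at timesteps n, 2n, ...), or an
  explicit comma-separated list of booleans such as "true,true,false",
  whose length must equal num_timesteps.
  """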
schedule = schedule.strip().lower()
if schedule == "always":
return [True] * (num_timesteps - 1) + [False]
elif schedule == "never":
return [False] * num_timesteps
elif "every" in schedule:
n = int(schedule.split("_")[1])
    return [(i + 1) % n == 0 for i in range(num_timesteps)]
else:
sched = [x.strip() == "true" for x in schedule.split(",")]
assert len(
sched
) == num_timesteps, "Wrong number of timesteps in resampling schedule."
return sched
def create_log_hook(step, eval_log_p_hat, eval_likelihood):
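  """Creates a hook that logs the step, log p_hat, and true likelihood."""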
def summ_formatter(d):
return ("Step {step}, log p_hat: {log_p_hat:.5f} likelihood: {likelihood:.5f}".format(**d))
hook = tf.train.LoggingTensorHook(
{
"step": step,
"log_p_hat": eval_log_p_hat,
"likelihood": eval_likelihood,
},
every_n_iter=FLAGS.summarize_every,
formatter=summ_formatter)
return hook
def create_infrequent_summary_hook():
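  """Creates a hook saving 'infrequent_summaries' every 10000 steps."""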
infrequent_summary_hook = tf.train.SummarySaverHook(
save_steps=10000,
output_dir=FLAGS.logdir,
summary_op=tf.summary.merge_all(key="infrequent_summaries")
)
return infrequent_summary_hook
def main(unused_argv):
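  """Validates flag combinations, builds the graph, and runs training."""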
if FLAGS.model == "long_chain":
resampling_schedule = parse_resampling_schedule(FLAGS.resampling,
FLAGS.num_timesteps + 1)
else:
resampling_schedule = parse_resampling_schedule(FLAGS.resampling,
FLAGS.num_timesteps)
if FLAGS.random_seed is None:
seed = np.random.randint(0, high=10000)
else:
seed = FLAGS.random_seed
tf.logging.info("Using random seed %d", seed)
if FLAGS.model == "long_chain":
assert FLAGS.q_type == "normal", "Q type %s not supported for long chain models" % FLAGS.q_type
assert FLAGS.p_type == "unimodal", "Bimodal priors are not supported for long chain models"
assert not FLAGS.use_bs, "Bs are not supported with long chain models"
assert FLAGS.num_timesteps == FLAGS.num_observations * FLAGS.steps_per_observation, "Num timesteps does not match."
assert FLAGS.bound != "fivo-aux-td", "TD Training is not compatible with long chain models."
if FLAGS.model == "forward":
if "nonlinear" not in FLAGS.p_type:
assert FLAGS.transition_type == models.STANDARD_TRANSITION, "Non-standard transitions not supported by the forward model."
assert FLAGS.observation_type == models.STANDARD_OBSERVATION, "Non-standard observations not supported by the forward model."
assert FLAGS.observation_variance is None, "Forward model does not support observation variance."
assert FLAGS.num_observations == 1, "Forward model only supports 1 observation."
if "relaxed" in FLAGS.resampling_method:
FLAGS.use_resampling_grads = False
assert FLAGS.bound != "fivo-aux-td", "TD Training is not compatible with relaxed resampling."
if FLAGS.observation_variance is None:
FLAGS.observation_variance = FLAGS.variance
if FLAGS.p_type == "bimodal":
assert FLAGS.bimodal_prior_mean is not None, "Must specify prior mean if using bimodal p."
if FLAGS.p_type == "nonlinear" or FLAGS.p_type == "nonlinear-cauchy":
assert not FLAGS.use_bs, "Using bs is not compatible with the nonlinear model."
g = tf.Graph()
with g.as_default():
# Set the seeds.
tf.set_random_seed(seed)
np.random.seed(seed)
if FLAGS.model == "long_chain":
(global_step, train_op, eval_log_p_hat,
eval_likelihood) = create_long_chain_graph(
FLAGS.bound,
FLAGS.state_size,
FLAGS.num_observations,
FLAGS.steps_per_observation,
FLAGS.batch_size,
FLAGS.num_samples,
FLAGS.num_eval_samples,
resampling_schedule,
FLAGS.use_resampling_grads,
FLAGS.learning_rate,
FLAGS.decay_steps)
else:
(global_step, train_op,
eval_log_p_hat, eval_likelihood) = create_graph(
FLAGS.bound,
FLAGS.state_size,
FLAGS.num_timesteps,
FLAGS.batch_size,
FLAGS.num_samples,
FLAGS.num_eval_samples,
resampling_schedule,
FLAGS.use_resampling_grads,
FLAGS.learning_rate,
FLAGS.decay_steps,
FLAGS.train_p)
log_hooks = [create_log_hook(global_step, eval_log_p_hat, eval_likelihood)]
if len(tf.get_collection("infrequent_summaries")) > 0:
log_hooks.append(create_infrequent_summary_hook())
tf.logging.info("trainable variables:")
tf.logging.info([v.name for v in tf.trainable_variables()])
tf.logging.info("p vars:")
tf.logging.info([v.name for v in tf.get_collection("P_VARS")])
tf.logging.info("q vars:")
tf.logging.info([v.name for v in tf.get_collection("Q_VARS")])
tf.logging.info("r vars:")
tf.logging.info([v.name for v in tf.get_collection("R_VARS")])
tf.logging.info("r tilde vars:")
tf.logging.info([v.name for v in tf.get_collection("R_TILDE_VARS")])
with tf.train.MonitoredTrainingSession(
master="",
is_chief=True,
hooks=log_hooks,
checkpoint_dir=FLAGS.logdir,
save_checkpoint_secs=120,
save_summaries_steps=FLAGS.summarize_every,
log_step_count_steps=FLAGS.summarize_every) as sess:
cur_step = -1
while True:
if sess.should_stop() or cur_step > FLAGS.max_steps:
break
# run a step
_, cur_step = sess.run([train_op, global_step])
if __name__ == "__main__":
tf.app.run(main)