# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Main script for running fivo."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import defaultdict

import numpy as np
import tensorflow as tf

import bounds
import data
import models
import summary_utils as summ

tf.logging.set_verbosity(tf.logging.INFO)

tf.app.flags.DEFINE_integer("random_seed", None,
                            "A random seed for the data generating process. "
                            "Same seed -> same data generating process and "
                            "initialization.")
tf.app.flags.DEFINE_enum("bound", "fivo",
                         ["iwae", "fivo", "fivo-aux", "fivo-aux-td"],
                         "The bound to optimize.")
tf.app.flags.DEFINE_enum("model", "forward", ["forward", "long_chain"],
                         "The model to use.")
tf.app.flags.DEFINE_enum("q_type", "normal",
                         ["normal", "simple_mean", "prev_state", "observation"],
                         "The parameterization to use for q")
tf.app.flags.DEFINE_enum("p_type", "unimodal",
                         ["unimodal", "bimodal", "nonlinear"],
                         "The type of prior.")
tf.app.flags.DEFINE_boolean("train_p", True,
                            "If false, do not train the model p.")
tf.app.flags.DEFINE_integer("state_size", 1,
                            "The dimensionality of the state space.")
tf.app.flags.DEFINE_float("variance", 1.0,
                          "The variance of the data generating process.")
tf.app.flags.DEFINE_boolean("use_bs", True,
                            "If False, initialize all bs to 0.")
tf.app.flags.DEFINE_float("bimodal_prior_weight", 0.5,
                          "The weight assigned to the positive mode of the "
                          "prior in both the data generating process and p.")
tf.app.flags.DEFINE_float("bimodal_prior_mean", None,
                          "If supplied, sets the mean of the 2 modes of the "
                          "prior to be 1 and -1 times the supplied value. "
                          "This is for both the data generating process "
                          "and p.")
tf.app.flags.DEFINE_float("fixed_observation", None,
                          "If supplied, fix the observation to a constant "
                          "value in the data generating process only.")
tf.app.flags.DEFINE_float("r_sigma_init", 1.,
                          "Value to initialize variance of r to.")
tf.app.flags.DEFINE_enum("observation_type",
                         models.STANDARD_OBSERVATION, models.OBSERVATION_TYPES,
                         "The type of observation for the long chain model.")
tf.app.flags.DEFINE_enum("transition_type",
                         models.STANDARD_TRANSITION, models.TRANSITION_TYPES,
                         "The type of transition for the long chain model.")
tf.app.flags.DEFINE_float("observation_variance", None,
                          "The variance of the observation. Defaults to "
                          "'variance'")
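
# Flags controlling sequence length, batching, resampling, optimization,
# and logging.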
tf.app.flags.DEFINE_integer("num_timesteps", 5,
                            "Number of timesteps in the sequence.")
tf.app.flags.DEFINE_integer("num_observations", 1,
                            "The number of observations.")
tf.app.flags.DEFINE_integer("steps_per_observation", 5,
                            "The number of timesteps between each observation.")
tf.app.flags.DEFINE_integer("batch_size", 4,
                            "The number of examples per batch.")
tf.app.flags.DEFINE_integer("num_samples", 4,
                            "The number of particles to use.")
tf.app.flags.DEFINE_integer("num_eval_samples", 512,
                            "The batch size and # of particles to use for "
                            "eval.")
tf.app.flags.DEFINE_string("resampling", "always",
                           "How to resample. Accepts 'always', 'never', or a "
                           "comma-separated list of booleans like "
                           "'true,true,false'.")
tf.app.flags.DEFINE_enum("resampling_method", "multinomial",
                         ["multinomial", "stratified", "systematic",
                          "relaxed-logblend", "relaxed-stateblend",
                          "relaxed-linearblend", "relaxed-stateblend-st"],
                         "Type of resampling method to use.")
tf.app.flags.DEFINE_boolean("use_resampling_grads", True,
                            "Whether or not to use resampling grads to "
                            "optimize FIVO. Disabled automatically if "
                            "resampling_method=relaxed.")
tf.app.flags.DEFINE_boolean("disable_r", False,
                            "If True, r is not used for fivo-aux and is set "
                            "to zeros.")
tf.app.flags.DEFINE_float("learning_rate", 1e-4,
                          "The learning rate to use for ADAM or SGD.")
tf.app.flags.DEFINE_integer("decay_steps", 25000,
                            "The number of steps before the learning rate is "
                            "halved.")
tf.app.flags.DEFINE_integer("max_steps", int(1e6),
                            "The number of steps to run training for.")
tf.app.flags.DEFINE_string("logdir", "/tmp/fivo-aux",
                           "Directory for summaries and checkpoints.")
tf.app.flags.DEFINE_integer("summarize_every", int(1e3),
                            "The number of steps between each evaluation.")

FLAGS = tf.app.flags.FLAGS


def combine_grad_lists(grad_lists):
  # grad_lists is num_losses by num_variables.
  # Each list could have different variables.
  # For each variable, sum the grads across all losses.
  grads_dict = defaultdict(list)
  var_dict = {}
  for grad_list in grad_lists:
    for grad, var in grad_list:
      if grad is not None:
        grads_dict[var.name].append(grad)
        var_dict[var.name] = var

  final_grads = []
  for var_name, var in var_dict.items():
    grads = grads_dict[var_name]
    if len(grads) > 0:
      tf.logging.info("Var %s has combined grads from %s." %
                      (var_name, [g.name for g in grads]))
      grad = tf.reduce_sum(grads, axis=0)
    else:
      tf.logging.info("Var %s has no grads" % var_name)
      grad = None
    final_grads.append((grad, var))
  return final_grads


def make_apply_grads_op(losses, global_step, learning_rate, lr_decay_steps):
  for l in losses:
    assert isinstance(l, bounds.Loss)

  lr = tf.train.exponential_decay(
      learning_rate, global_step, lr_decay_steps, 0.5, staircase=False)
  tf.summary.scalar("learning_rate", lr)
  opt = tf.train.AdamOptimizer(lr)

  ema_ops = []
  grads = []
  for loss_name, loss, loss_var_collection in losses:
    tf.logging.info("Computing grads of %s w.r.t. vars in collection %s" %
                    (loss_name, loss_var_collection))
    g = opt.compute_gradients(
        loss, var_list=tf.get_collection(loss_var_collection))
    ema_ops.append(summ.summarize_grads(g, loss_name))
    grads.append(g)
  all_grads = combine_grad_lists(grads)
  apply_grads_op = opt.apply_gradients(all_grads, global_step=global_step)

  # Update the emas after applying the grads.
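  # The EMA ops returned by summ.summarize_grads are grouped under a control
  # dependency so they run only after the optimizer step has been applied.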
  with tf.control_dependencies([apply_grads_op]):
    train_op = tf.group(*ema_ops)
  return train_op


def add_check_numerics_ops():
  check_op = []
  for op in tf.get_default_graph().get_operations():
    bad = ["logits/Log", "sample/Reshape", "log_prob/mul",
           "log_prob/SparseSoftmaxCrossEntropyWithLogits/Reshape",
           "entropy/Reshape", "entropy/LogSoftmax", "Categorical", "Mean"]
    if all([x not in op.name for x in bad]):
      for output in op.outputs:
        if output.dtype in [tf.float16, tf.float32, tf.float64]:
          if op._get_control_flow_context() is not None:  # pylint: disable=protected-access
            raise ValueError("`tf.add_check_numerics_ops() is not compatible "
                             "with TensorFlow control flow operations such as "
                             "`tf.cond()` or `tf.while_loop()`.")
          message = op.name + ":" + str(output.value_index)
          with tf.control_dependencies(check_op):
            check_op = [tf.check_numerics(output, message=message)]
  return tf.group(*check_op)


def create_long_chain_graph(bound, state_size, num_obs, steps_per_obs,
                            batch_size, num_samples, num_eval_samples,
                            resampling_schedule, use_resampling_grads,
                            learning_rate, lr_decay_steps, dtype="float64"):
  num_timesteps = num_obs * steps_per_obs + 1

  # Make the dataset.
  dataset = data.make_long_chain_dataset(
      state_size=state_size,
      num_obs=num_obs,
      steps_per_obs=steps_per_obs,
      batch_size=batch_size,
      num_samples=num_samples,
      variance=FLAGS.variance,
      observation_variance=FLAGS.observation_variance,
      dtype=dtype,
      observation_type=FLAGS.observation_type,
      transition_type=FLAGS.transition_type,
      fixed_observation=FLAGS.fixed_observation)
  itr = dataset.make_one_shot_iterator()
  _, observations = itr.get_next()

  # Make the dataset for eval.
  eval_dataset = data.make_long_chain_dataset(
      state_size=state_size,
      num_obs=num_obs,
      steps_per_obs=steps_per_obs,
      batch_size=batch_size,
      num_samples=num_eval_samples,
      variance=FLAGS.variance,
      observation_variance=FLAGS.observation_variance,
      dtype=dtype,
      observation_type=FLAGS.observation_type,
      transition_type=FLAGS.transition_type,
      fixed_observation=FLAGS.fixed_observation)
  eval_itr = eval_dataset.make_one_shot_iterator()
  _, eval_observations = eval_itr.get_next()

  # Make the model.
  model = models.LongChainModel.create(
      state_size,
      num_obs,
      steps_per_obs,
      observation_type=FLAGS.observation_type,
      transition_type=FLAGS.transition_type,
      variance=FLAGS.variance,
      observation_variance=FLAGS.observation_variance,
      dtype=tf.as_dtype(dtype),
      disable_r=FLAGS.disable_r)

  # Compute the bound and loss.
  if bound == "iwae":
    (_, losses, ema_op, _, _) = bounds.iwae(
        model,
        observations,
        num_timesteps,
        num_samples=num_samples)
    (eval_log_p_hat, _, _, _, eval_log_weights) = bounds.iwae(
        model,
        eval_observations,
        num_timesteps,
        num_samples=num_eval_samples,
        summarize=False)
    eval_log_p_hat = tf.reduce_mean(eval_log_p_hat)
  elif bound in ("fivo", "fivo-aux"):
    (_, losses, ema_op, _, _) = bounds.fivo(
        model,
        observations,
        num_timesteps,
        resampling_schedule=resampling_schedule,
        use_resampling_grads=use_resampling_grads,
        resampling_type=FLAGS.resampling_method,
        aux=("aux" in bound),
        num_samples=num_samples)
    (eval_log_p_hat, _, _, _, eval_log_weights) = bounds.fivo(
        model,
        eval_observations,
        num_timesteps,
        resampling_schedule=resampling_schedule,
        use_resampling_grads=False,
        resampling_type="multinomial",
        aux=("aux" in bound),
        num_samples=num_eval_samples,
        summarize=False)
    eval_log_p_hat = tf.reduce_mean(eval_log_p_hat)

  summ.summarize_ess(eval_log_weights, only_last_timestep=True)

  tf.summary.scalar("log_p_hat", eval_log_p_hat)

  # Compute and apply grads.
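  # Each entry in `losses` is a bounds.Loss tuple of (name, loss tensor,
  # variable collection); make_apply_grads_op sums gradients per variable
  # across the losses before applying them.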
  global_step = tf.train.get_or_create_global_step()
  apply_grads = make_apply_grads_op(losses,
                                    global_step,
                                    learning_rate,
                                    lr_decay_steps)

  # Update the emas after applying the grads.
  with tf.control_dependencies([apply_grads]):
    train_op = tf.group(ema_op)

  # We can't calculate the likelihood for most of these models
  # so we just return zeros.
  eval_likelihood = tf.zeros([], dtype=dtype)

  return global_step, train_op, eval_log_p_hat, eval_likelihood


def create_graph(bound, state_size, num_timesteps, batch_size,
                 num_samples, num_eval_samples, resampling_schedule,
                 use_resampling_grads, learning_rate, lr_decay_steps,
                 train_p, dtype="float64"):
  if FLAGS.use_bs:
    true_bs = None
  else:
    true_bs = [np.zeros([state_size]).astype(dtype)
               for _ in range(num_timesteps)]

  # Make the dataset.
  true_bs, dataset = data.make_dataset(
      bs=true_bs,
      state_size=state_size,
      num_timesteps=num_timesteps,
      batch_size=batch_size,
      num_samples=num_samples,
      variance=FLAGS.variance,
      prior_type=FLAGS.p_type,
      bimodal_prior_weight=FLAGS.bimodal_prior_weight,
      bimodal_prior_mean=FLAGS.bimodal_prior_mean,
      transition_type=FLAGS.transition_type,
      fixed_observation=FLAGS.fixed_observation,
      dtype=dtype)
  itr = dataset.make_one_shot_iterator()
  _, observations = itr.get_next()

  # Make the dataset for eval.
  _, eval_dataset = data.make_dataset(
      bs=true_bs,
      state_size=state_size,
      num_timesteps=num_timesteps,
      batch_size=num_eval_samples,
      num_samples=num_eval_samples,
      variance=FLAGS.variance,
      prior_type=FLAGS.p_type,
      bimodal_prior_weight=FLAGS.bimodal_prior_weight,
      bimodal_prior_mean=FLAGS.bimodal_prior_mean,
      transition_type=FLAGS.transition_type,
      fixed_observation=FLAGS.fixed_observation,
      dtype=dtype)
  eval_itr = eval_dataset.make_one_shot_iterator()
  _, eval_observations = eval_itr.get_next()

  # Make the model.
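  # The "fivo-aux-td" bound uses the separate models.TDModel class; all other
  # bounds use models.Model.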
  if bound == "fivo-aux-td":
    model = models.TDModel.create(
        state_size,
        num_timesteps,
        variance=FLAGS.variance,
        train_p=train_p,
        p_type=FLAGS.p_type,
        q_type=FLAGS.q_type,
        mixing_coeff=FLAGS.bimodal_prior_weight,
        prior_mode_mean=FLAGS.bimodal_prior_mean,
        observation_variance=FLAGS.observation_variance,
        transition_type=FLAGS.transition_type,
        use_bs=FLAGS.use_bs,
        dtype=tf.as_dtype(dtype),
        random_seed=FLAGS.random_seed)
  else:
    model = models.Model.create(
        state_size,
        num_timesteps,
        variance=FLAGS.variance,
        train_p=train_p,
        p_type=FLAGS.p_type,
        q_type=FLAGS.q_type,
        mixing_coeff=FLAGS.bimodal_prior_weight,
        prior_mode_mean=FLAGS.bimodal_prior_mean,
        observation_variance=FLAGS.observation_variance,
        transition_type=FLAGS.transition_type,
        use_bs=FLAGS.use_bs,
        r_sigma_init=FLAGS.r_sigma_init,
        dtype=tf.as_dtype(dtype),
        random_seed=FLAGS.random_seed)

  # Compute the bound and loss
  if bound == "iwae":
    (_, losses, ema_op, _, _) = bounds.iwae(
        model,
        observations,
        num_timesteps,
        num_samples=num_samples)
    (eval_log_p_hat, _, _, eval_states, eval_log_weights) = bounds.iwae(
        model,
        eval_observations,
        num_timesteps,
        num_samples=num_eval_samples,
        summarize=True)
    eval_log_p_hat = tf.reduce_mean(eval_log_p_hat)
  elif "fivo" in bound:
    if bound == "fivo-aux-td":
      (_, losses, ema_op, _, _) = bounds.fivo_aux_td(
          model,
          observations,
          num_timesteps,
          resampling_schedule=resampling_schedule,
          num_samples=num_samples)
      (eval_log_p_hat, _, _, eval_states, eval_log_weights) = bounds.fivo_aux_td(
          model,
          eval_observations,
          num_timesteps,
          resampling_schedule=resampling_schedule,
          num_samples=num_eval_samples,
          summarize=True)
    else:
      (_, losses, ema_op, _, _) = bounds.fivo(
          model,
          observations,
          num_timesteps,
          resampling_schedule=resampling_schedule,
          use_resampling_grads=use_resampling_grads,
          resampling_type=FLAGS.resampling_method,
          aux=("aux" in bound),
          num_samples=num_samples)
      (eval_log_p_hat, _, _, eval_states, eval_log_weights) = bounds.fivo(
          model,
          eval_observations,
          num_timesteps,
          resampling_schedule=resampling_schedule,
          use_resampling_grads=False,
          resampling_type="multinomial",
          aux=("aux" in bound),
          num_samples=num_eval_samples,
          summarize=True)
    eval_log_p_hat = tf.reduce_mean(eval_log_p_hat)

  summ.summarize_ess(eval_log_weights, only_last_timestep=True)

  # if FLAGS.p_type == "bimodal":
  #   # create the observations that showcase the model.
  #   mode_odds_ratio = tf.convert_to_tensor([1., 3., 1./3., 512., 1./512.],
  #                                          dtype=tf.float64)
  #   mode_odds_ratio = tf.expand_dims(mode_odds_ratio, 1)
  #   k = ((num_timesteps+1) * FLAGS.variance) / (2*FLAGS.bimodal_prior_mean)
  #   explain_obs = tf.reduce_sum(model.p.bs) + tf.log(mode_odds_ratio) * k
  #   explain_obs = tf.tile(explain_obs, [num_eval_samples, 1])
  #   # run the model on the explainable observations
  #   if bound == "iwae":
  #     (_, _, _, explain_states, explain_log_weights) = bounds.iwae(
  #         model,
  #         explain_obs,
  #         num_timesteps,
  #         num_samples=num_eval_samples)
  #   elif bound == "fivo" or "fivo-aux":
  #     (_, _, _, explain_states, explain_log_weights) = bounds.fivo(
  #         model,
  #         explain_obs,
  #         num_timesteps,
  #         resampling_schedule=resampling_schedule,
  #         use_resampling_grads=False,
  #         resampling_type="multinomial",
  #         aux=("aux" in bound),
  #         num_samples=num_eval_samples)
  #   summ.summarize_particles(explain_states,
  #                            explain_log_weights,
  #                            explain_obs,
  #                            model)

  # Calculate the true likelihood.
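  # Only priors that expose a `likelihood` method report the analytic value
  # (normalized by the number of timesteps); otherwise zeros are logged so the
  # bound_gap summary below is still defined.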
  if hasattr(model.p, "likelihood") and callable(getattr(model.p, "likelihood")):
    eval_likelihood = model.p.likelihood(eval_observations) / FLAGS.num_timesteps
  else:
    eval_likelihood = tf.zeros_like(eval_log_p_hat)

  tf.summary.scalar("log_p_hat", eval_log_p_hat)
  tf.summary.scalar("likelihood", eval_likelihood)
  tf.summary.scalar("bound_gap", eval_likelihood - eval_log_p_hat)
  summ.summarize_model(model, true_bs, eval_observations, eval_states, bound,
                       summarize_r=not bound == "fivo-aux-td")

  # Compute and apply grads.
  global_step = tf.train.get_or_create_global_step()
  apply_grads = make_apply_grads_op(losses,
                                    global_step,
                                    learning_rate,
                                    lr_decay_steps)

  # Update the emas after applying the grads.
  with tf.control_dependencies([apply_grads]):
    train_op = tf.group(ema_op)
    # train_op = tf.group(ema_op, add_check_numerics_ops())

  return global_step, train_op, eval_log_p_hat, eval_likelihood


def parse_resampling_schedule(schedule, num_timesteps):
  schedule = schedule.strip().lower()
  if schedule == "always":
    return [True] * (num_timesteps - 1) + [False]
  elif schedule == "never":
    return [False] * num_timesteps
  elif "every" in schedule:
    n = int(schedule.split("_")[1])
    return [(i + 1) % n == 0 for i in range(num_timesteps)]
  else:
    sched = [x.strip() == "true" for x in schedule.split(",")]
    assert len(sched) == num_timesteps, ("Wrong number of timesteps in "
                                         "resampling schedule.")
    return sched


def create_log_hook(step, eval_log_p_hat, eval_likelihood):
  def summ_formatter(d):
    return ("Step {step}, log p_hat: {log_p_hat:.5f} "
            "likelihood: {likelihood:.5f}".format(**d))
  hook = tf.train.LoggingTensorHook(
      {
          "step": step,
          "log_p_hat": eval_log_p_hat,
          "likelihood": eval_likelihood,
      },
      every_n_iter=FLAGS.summarize_every,
      formatter=summ_formatter)
  return hook


def create_infrequent_summary_hook():
  infrequent_summary_hook = tf.train.SummarySaverHook(
      save_steps=10000,
      output_dir=FLAGS.logdir,
      summary_op=tf.summary.merge_all(key="infrequent_summaries"))
  return infrequent_summary_hook


def main(unused_argv):
  if FLAGS.model == "long_chain":
    resampling_schedule = parse_resampling_schedule(FLAGS.resampling,
                                                    FLAGS.num_timesteps + 1)
  else:
    resampling_schedule = parse_resampling_schedule(FLAGS.resampling,
                                                    FLAGS.num_timesteps)
  if FLAGS.random_seed is None:
    seed = np.random.randint(0, high=10000)
  else:
    seed = FLAGS.random_seed
  tf.logging.info("Using random seed %d", seed)

  if FLAGS.model == "long_chain":
    assert FLAGS.q_type == "normal", (
        "Q type %s not supported for long chain models" % FLAGS.q_type)
    assert FLAGS.p_type == "unimodal", (
        "Bimodal priors are not supported for long chain models")
    assert not FLAGS.use_bs, "Bs are not supported with long chain models"
    assert FLAGS.num_timesteps == (
        FLAGS.num_observations * FLAGS.steps_per_observation), (
            "Num timesteps does not match.")
    assert FLAGS.bound != "fivo-aux-td", (
        "TD Training is not compatible with long chain models.")

  if FLAGS.model == "forward":
    if "nonlinear" not in FLAGS.p_type:
      assert FLAGS.transition_type == models.STANDARD_TRANSITION, (
          "Non-standard transitions not supported by the forward model.")
      assert FLAGS.observation_type == models.STANDARD_OBSERVATION, (
          "Non-standard observations not supported by the forward model.")
    assert FLAGS.observation_variance is None, (
        "Forward model does not support observation variance.")
    assert FLAGS.num_observations == 1, (
        "Forward model only supports 1 observation.")

  if "relaxed" in FLAGS.resampling_method:
    FLAGS.use_resampling_grads = False
    assert FLAGS.bound != "fivo-aux-td", (
        "TD Training is not compatible with relaxed resampling.")
  if FLAGS.observation_variance is None:
    FLAGS.observation_variance = FLAGS.variance

  if FLAGS.p_type == "bimodal":
    assert FLAGS.bimodal_prior_mean is not None, (
        "Must specify prior mean if using bimodal p.")

  if FLAGS.p_type == "nonlinear" or FLAGS.p_type == "nonlinear-cauchy":
    assert not FLAGS.use_bs, (
        "Using bs is not compatible with the nonlinear model.")

  g = tf.Graph()
  with g.as_default():
    # Set the seeds.
    tf.set_random_seed(seed)
    np.random.seed(seed)

    if FLAGS.model == "long_chain":
      (global_step, train_op, eval_log_p_hat,
       eval_likelihood) = create_long_chain_graph(
           FLAGS.bound,
           FLAGS.state_size,
           FLAGS.num_observations,
           FLAGS.steps_per_observation,
           FLAGS.batch_size,
           FLAGS.num_samples,
           FLAGS.num_eval_samples,
           resampling_schedule,
           FLAGS.use_resampling_grads,
           FLAGS.learning_rate,
           FLAGS.decay_steps)
    else:
      (global_step, train_op, eval_log_p_hat,
       eval_likelihood) = create_graph(
           FLAGS.bound,
           FLAGS.state_size,
           FLAGS.num_timesteps,
           FLAGS.batch_size,
           FLAGS.num_samples,
           FLAGS.num_eval_samples,
           resampling_schedule,
           FLAGS.use_resampling_grads,
           FLAGS.learning_rate,
           FLAGS.decay_steps,
           FLAGS.train_p)

    log_hooks = [create_log_hook(global_step, eval_log_p_hat, eval_likelihood)]
    if len(tf.get_collection("infrequent_summaries")) > 0:
      log_hooks.append(create_infrequent_summary_hook())

    tf.logging.info("trainable variables:")
    tf.logging.info([v.name for v in tf.trainable_variables()])
    tf.logging.info("p vars:")
    tf.logging.info([v.name for v in tf.get_collection("P_VARS")])
    tf.logging.info("q vars:")
    tf.logging.info([v.name for v in tf.get_collection("Q_VARS")])
    tf.logging.info("r vars:")
    tf.logging.info([v.name for v in tf.get_collection("R_VARS")])
    tf.logging.info("r tilde vars:")
    tf.logging.info([v.name for v in tf.get_collection("R_TILDE_VARS")])

    with tf.train.MonitoredTrainingSession(
        master="",
        is_chief=True,
        hooks=log_hooks,
        checkpoint_dir=FLAGS.logdir,
        save_checkpoint_secs=120,
        save_summaries_steps=FLAGS.summarize_every,
        log_step_count_steps=FLAGS.summarize_every) as sess:
      cur_step = -1
      while True:
        if sess.should_stop() or cur_step > FLAGS.max_steps:
          break
        # Run a step.
        _, cur_step = sess.run([train_op, global_step])


if __name__ == "__main__":
  tf.app.run(main)
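
# Example invocations (illustrative only; flag values are arbitrary, and
# "run_fivo.py" stands in for whatever this file is named). Every flag used
# below is defined at the top of this file:
#
#   python run_fivo.py --model=forward --bound=fivo --p_type=unimodal \
#     --num_timesteps=5 --num_samples=4 --batch_size=4 --logdir=/tmp/fivo-aux
#
#   python run_fivo.py --model=long_chain --bound=iwae --use_bs=false \
#     --num_observations=1 --steps_per_observation=5 --num_timesteps=5 \
#     --logdir=/tmp/fivo-aux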