# Copyright 2018 The TensorFlow Authors All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Model.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import functools import sonnet as snt import tensorflow as tf import numpy as np import math SQUARED_OBSERVATION = "squared" ABS_OBSERVATION = "abs" STANDARD_OBSERVATION = "standard" OBSERVATION_TYPES = [SQUARED_OBSERVATION, ABS_OBSERVATION, STANDARD_OBSERVATION] ROUND_TRANSITION = "round" STANDARD_TRANSITION = "standard" TRANSITION_TYPES = [ROUND_TRANSITION, STANDARD_TRANSITION] class Q(object): def __init__(self, state_size, num_timesteps, sigma_min=1e-5, dtype=tf.float32, random_seed=None, init_mu0_to_zero=False, graph_collection_name="Q_VARS"): self.sigma_min = sigma_min self.dtype = dtype self.graph_collection_name = graph_collection_name initializers = [] for t in xrange(num_timesteps): if t == 0 and init_mu0_to_zero: initializers.append( {"w": tf.zeros_initializer, "b": tf.zeros_initializer}) else: initializers.append( {"w": tf.random_uniform_initializer(seed=random_seed), "b": tf.zeros_initializer}) def custom_getter(getter, *args, **kwargs): out = getter(*args, **kwargs) ref = tf.get_collection_ref(self.graph_collection_name) if out not in ref: ref.append(out) return out self.mus = [ snt.Linear(output_size=state_size, initializers=initializers[t], name="q_mu_%d" % t, custom_getter=custom_getter ) for t in xrange(num_timesteps) ] self.sigmas = [ tf.get_variable( shape=[state_size], dtype=self.dtype, name="q_sigma_%d" % (t + 1), collections=[tf.GraphKeys.GLOBAL_VARIABLES, graph_collection_name], initializer=tf.random_uniform_initializer(seed=random_seed)) for t in xrange(num_timesteps) ] def q_zt(self, observation, prev_state, t): batch_size = tf.shape(prev_state)[0] q_mu = self.mus[t](tf.concat([observation, prev_state], axis=1)) q_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min) q_sigma = tf.tile(q_sigma[tf.newaxis, :], [batch_size, 1]) q_zt = tf.contrib.distributions.Normal(loc=q_mu, scale=tf.sqrt(q_sigma)) return q_zt def summarize_weights(self): for t, sigma in enumerate(self.sigmas): tf.summary.scalar("q_sigma/%d" % t, sigma[0]) for t, f in enumerate(self.mus): tf.summary.scalar("q_mu/b_%d" % t, f.b[0]) tf.summary.scalar("q_mu/w_obs_%d" % t, f.w[0,0]) if t != 0: tf.summary.scalar("q_mu/w_prev_state_%d" % t, f.w[1,0]) class PreviousStateQ(Q): def q_zt(self, unused_observation, prev_state, t): batch_size = tf.shape(prev_state)[0] q_mu = self.mus[t](prev_state) q_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min) q_sigma = tf.tile(q_sigma[tf.newaxis, :], [batch_size, 1]) q_zt = tf.contrib.distributions.Normal(loc=q_mu, scale=tf.sqrt(q_sigma)) return q_zt def summarize_weights(self): for t, sigma in enumerate(self.sigmas): tf.summary.scalar("q_sigma/%d" % t, sigma[0]) for t, f in enumerate(self.mus): tf.summary.scalar("q_mu/b_%d" % t, f.b[0]) tf.summary.scalar("q_mu/w_prev_state_%d" % t, f.w[0,0]) class ObservationQ(Q): def q_zt(self, observation, prev_state, t): batch_size = tf.shape(prev_state)[0] q_mu = self.mus[t](observation) q_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min) q_sigma = tf.tile(q_sigma[tf.newaxis, :], [batch_size, 1]) q_zt = tf.contrib.distributions.Normal(loc=q_mu, scale=tf.sqrt(q_sigma)) return q_zt def summarize_weights(self): for t, sigma in enumerate(self.sigmas): tf.summary.scalar("q_sigma/%d" % t, sigma[0]) for t, f in enumerate(self.mus): tf.summary.scalar("q_mu/b_%d" % t, f.b[0]) tf.summary.scalar("q_mu/w_obs_%d" % t, f.w[0,0]) class SimpleMeanQ(object): def __init__(self, state_size, num_timesteps, sigma_min=1e-5, dtype=tf.float32, random_seed=None, init_mu0_to_zero=False, graph_collection_name="Q_VARS"): self.sigma_min = sigma_min self.dtype = dtype self.graph_collection_name = graph_collection_name initializers = [] for t in xrange(num_timesteps): if t == 0 and init_mu0_to_zero: initializers.append(tf.zeros_initializer) else: initializers.append(tf.random_uniform_initializer(seed=random_seed)) self.mus = [ tf.get_variable( shape=[state_size], dtype=self.dtype, name="q_mu_%d" % (t + 1), collections=[tf.GraphKeys.GLOBAL_VARIABLES, graph_collection_name], initializer=initializers[t]) for t in xrange(num_timesteps) ] self.sigmas = [ tf.get_variable( shape=[state_size], dtype=self.dtype, name="q_sigma_%d" % (t + 1), collections=[tf.GraphKeys.GLOBAL_VARIABLES, graph_collection_name], initializer=tf.random_uniform_initializer(seed=random_seed)) for t in xrange(num_timesteps) ] def q_zt(self, unused_observation, prev_state, t): batch_size = tf.shape(prev_state)[0] q_mu = tf.tile(self.mus[t][tf.newaxis, :], [batch_size, 1]) q_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min) q_sigma = tf.tile(q_sigma[tf.newaxis, :], [batch_size, 1]) q_zt = tf.contrib.distributions.Normal(loc=q_mu, scale=tf.sqrt(q_sigma)) return q_zt def summarize_weights(self): for t, sigma in enumerate(self.sigmas): tf.summary.scalar("q_sigma/%d" % t, sigma[0]) for t, f in enumerate(self.mus): tf.summary.scalar("q_mu/%d" % t, f[0]) class R(object): def __init__(self, state_size, num_timesteps, sigma_min=1e-5, dtype=tf.float32, sigma_init=1., random_seed=None, graph_collection_name="R_VARS"): self.dtype = dtype self.sigma_min = sigma_min initializers = {"w": tf.truncated_normal_initializer(seed=random_seed), "b": tf.zeros_initializer} self.graph_collection_name=graph_collection_name def custom_getter(getter, *args, **kwargs): out = getter(*args, **kwargs) ref = tf.get_collection_ref(self.graph_collection_name) if out not in ref: ref.append(out) return out self.mus= [ snt.Linear(output_size=state_size, initializers=initializers, name="r_mu_%d" % t, custom_getter=custom_getter) for t in xrange(num_timesteps) ] self.sigmas = [ tf.get_variable( shape=[state_size], dtype=self.dtype, name="r_sigma_%d" % (t + 1), collections=[tf.GraphKeys.GLOBAL_VARIABLES, graph_collection_name], #initializer=tf.random_uniform_initializer(seed=random_seed, maxval=100)) initializer=tf.constant_initializer(sigma_init)) for t in xrange(num_timesteps) ] def r_xn(self, z_t, t): batch_size = tf.shape(z_t)[0] r_mu = self.mus[t](z_t) r_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min) r_sigma = tf.tile(r_sigma[tf.newaxis, :], [batch_size, 1]) return tf.contrib.distributions.Normal( loc=r_mu, scale=tf.sqrt(r_sigma)) def summarize_weights(self): for t in range(len(self.mus) - 1): tf.summary.scalar("r_mu/%d" % t, self.mus[t][0]) tf.summary.scalar("r_sigma/%d" % t, self.sigmas[t][0]) class P(object): def __init__(self, state_size, num_timesteps, sigma_min=1e-5, variance=1.0, dtype=tf.float32, random_seed=None, trainable=True, init_bs_to_zero=False, graph_collection_name="P_VARS"): self.state_size = state_size self.num_timesteps = num_timesteps self.sigma_min = sigma_min self.dtype = dtype self.variance = variance self.graph_collection_name = graph_collection_name if init_bs_to_zero: initializers = [tf.zeros_initializer for _ in xrange(num_timesteps)] else: initializers = [tf.random_uniform_initializer(seed=random_seed) for _ in xrange(num_timesteps)] self.bs = [ tf.get_variable( shape=[state_size], dtype=self.dtype, name="p_b_%d" % (t + 1), initializer=initializers[t], collections=[tf.GraphKeys.GLOBAL_VARIABLES, graph_collection_name], trainable=trainable) for t in xrange(num_timesteps) ] self.Bs = tf.cumsum(self.bs, reverse=True, axis=0) def posterior(self, observation, prev_state, t): """Computes the true posterior p(z_t|z_{t-1}, z_n).""" # bs[0] is really b_1 # Bs[i] is sum from k=i+1^n b_k mu = observation - self.Bs[t] if t > 0: mu += (prev_state + self.bs[t - 1]) * float(self.num_timesteps - t) mu /= float(self.num_timesteps - t + 1) sigma = tf.ones_like(mu) * self.variance * ( float(self.num_timesteps - t) / float(self.num_timesteps - t + 1)) return tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma)) def lookahead(self, state, t): """Computes the true lookahead distribution p(z_n|z_t).""" mu = state + self.Bs[t] sigma = tf.ones_like(state) * self.variance * float(self.num_timesteps - t) return tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma)) def likelihood(self, observation): batch_size = tf.shape(observation)[0] mu = tf.tile(tf.reduce_sum(self.bs, axis=0)[tf.newaxis, :], [batch_size, 1]) sigma = tf.ones_like(mu) * self.variance * (self.num_timesteps + 1) dist = tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma)) # Average over the batch and take the sum over the state size return tf.reduce_mean(tf.reduce_sum(dist.log_prob(observation), axis=1)) def p_zt(self, prev_state, t): """Computes the model p(z_t| z_{t-1}).""" batch_size = tf.shape(prev_state)[0] if t > 0: z_mu_p = prev_state + self.bs[t - 1] else: # p(z_0) is Normal(0,1) z_mu_p = tf.zeros([batch_size, self.state_size], dtype=self.dtype) p_zt = tf.contrib.distributions.Normal( loc=z_mu_p, scale=tf.sqrt(tf.ones_like(z_mu_p) * self.variance)) return p_zt def generative(self, unused_observation, z_nm1): """Computes the model's generative distribution p(z_n| z_{n-1}).""" generative_p_mu = z_nm1 + self.bs[-1] return tf.contrib.distributions.Normal( loc=generative_p_mu, scale=tf.sqrt(tf.ones_like(generative_p_mu) * self.variance)) class ShortChainNonlinearP(object): def __init__(self, state_size, num_timesteps, sigma_min=1e-5, variance=1.0, observation_variance=1.0, transition_type=STANDARD_TRANSITION, transition_dist=tf.contrib.distributions.Normal, dtype=tf.float32, random_seed=None): self.state_size = state_size self.num_timesteps = num_timesteps self.sigma_min = sigma_min self.dtype = dtype self.variance = variance self.observation_variance = observation_variance self.transition_type = transition_type self.transition_dist = transition_dist def p_zt(self, prev_state, t): """Computes the model p(z_t| z_{t-1}).""" batch_size = tf.shape(prev_state)[0] if t > 0: if self.transition_type == ROUND_TRANSITION: loc = tf.round(prev_state) tf.logging.info("p(z_%d | z_%d) ~ N(round(z_%d), %0.1f)" % (t, t-1, t-1, self.variance)) elif self.transition_type == STANDARD_TRANSITION: loc = prev_state tf.logging.info("p(z_%d | z_%d) ~ N(z_%d, %0.1f)" % (t, t-1, t-1, self.variance)) else: # p(z_0) is Normal(0,1) loc = tf.zeros([batch_size, self.state_size], dtype=self.dtype) tf.logging.info("p(z_0) ~ N(0,%0.1f)" % self.variance) p_zt = self.transition_dist( loc=loc, scale=tf.sqrt(tf.ones_like(loc) * self.variance)) return p_zt def generative(self, unused_obs, z_ni): """Computes the model's generative distribution p(x_i| z_{ni}).""" if self.transition_type == ROUND_TRANSITION: loc = tf.round(z_ni) elif self.transition_type == STANDARD_TRANSITION: loc = z_ni generative_sigma_sq = tf.ones_like(loc) * self.observation_variance return self.transition_dist( loc=loc, scale=tf.sqrt(generative_sigma_sq)) class BimodalPriorP(object): def __init__(self, state_size, num_timesteps, mixing_coeff=0.5, prior_mode_mean=1, sigma_min=1e-5, variance=1.0, dtype=tf.float32, random_seed=None, trainable=True, init_bs_to_zero=False, graph_collection_name="P_VARS"): self.state_size = state_size self.num_timesteps = num_timesteps self.sigma_min = sigma_min self.dtype = dtype self.variance = variance self.mixing_coeff = mixing_coeff self.prior_mode_mean = prior_mode_mean if init_bs_to_zero: initializers = [tf.zeros_initializer for _ in xrange(num_timesteps)] else: initializers = [tf.random_uniform_initializer(seed=random_seed) for _ in xrange(num_timesteps)] self.bs = [ tf.get_variable( shape=[state_size], dtype=self.dtype, name="b_%d" % (t + 1), initializer=initializers[t], collections=[tf.GraphKeys.GLOBAL_VARIABLES, graph_collection_name], trainable=trainable) for t in xrange(num_timesteps) ] self.Bs = tf.cumsum(self.bs, reverse=True, axis=0) def posterior(self, observation, prev_state, t): # NOTE: This is currently wrong, but would require a refactoring of # summarize_q to fix as kl is not defined for a mixture """Computes the true posterior p(z_t|z_{t-1}, z_n).""" # bs[0] is really b_1 # Bs[i] is sum from k=i+1^n b_k mu = observation - self.Bs[t] if t > 0: mu += (prev_state + self.bs[t - 1]) * float(self.num_timesteps - t) mu /= float(self.num_timesteps - t + 1) sigma = tf.ones_like(mu) * self.variance * ( float(self.num_timesteps - t) / float(self.num_timesteps - t + 1)) return tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma)) def lookahead(self, state, t): """Computes the true lookahead distribution p(z_n|z_t).""" mu = state + self.Bs[t] sigma = tf.ones_like(state) * self.variance * float(self.num_timesteps - t) return tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma)) def likelihood(self, observation): batch_size = tf.shape(observation)[0] sum_of_bs = tf.tile(tf.reduce_sum(self.bs, axis=0)[tf.newaxis, :], [batch_size, 1]) sigma = tf.ones_like(sum_of_bs) * self.variance * (self.num_timesteps + 1) mu_pos = (tf.ones([batch_size, self.state_size], dtype=self.dtype) * self.prior_mode_mean) + sum_of_bs mu_neg = (tf.ones([batch_size, self.state_size], dtype=self.dtype) * -self.prior_mode_mean) + sum_of_bs zn_pos = tf.contrib.distributions.Normal( loc=mu_pos, scale=tf.sqrt(sigma)) zn_neg = tf.contrib.distributions.Normal( loc=mu_neg, scale=tf.sqrt(sigma)) mode_probs = tf.convert_to_tensor([self.mixing_coeff, 1-self.mixing_coeff], dtype=tf.float64) mode_probs = tf.tile(mode_probs[tf.newaxis, tf.newaxis, :], [batch_size, 1, 1]) mode_selection_dist = tf.contrib.distributions.Categorical(probs=mode_probs) zn_dist = tf.contrib.distributions.Mixture( cat=mode_selection_dist, components=[zn_pos, zn_neg], validate_args=True) # Average over the batch and take the sum over the state size return tf.reduce_mean(tf.reduce_sum(zn_dist.log_prob(observation), axis=1)) def p_zt(self, prev_state, t): """Computes the model p(z_t| z_{t-1}).""" batch_size = tf.shape(prev_state)[0] if t > 0: z_mu_p = prev_state + self.bs[t - 1] p_zt = tf.contrib.distributions.Normal( loc=z_mu_p, scale=tf.sqrt(tf.ones_like(z_mu_p) * self.variance)) return p_zt else: # p(z_0) is mixture of two Normals mu_pos = tf.ones([batch_size, self.state_size], dtype=self.dtype) * self.prior_mode_mean mu_neg = tf.ones([batch_size, self.state_size], dtype=self.dtype) * -self.prior_mode_mean z0_pos = tf.contrib.distributions.Normal( loc=mu_pos, scale=tf.sqrt(tf.ones_like(mu_pos) * self.variance)) z0_neg = tf.contrib.distributions.Normal( loc=mu_neg, scale=tf.sqrt(tf.ones_like(mu_neg) * self.variance)) mode_probs = tf.convert_to_tensor([self.mixing_coeff, 1-self.mixing_coeff], dtype=tf.float64) mode_probs = tf.tile(mode_probs[tf.newaxis, tf.newaxis, :], [batch_size, 1, 1]) mode_selection_dist = tf.contrib.distributions.Categorical(probs=mode_probs) z0_dist = tf.contrib.distributions.Mixture( cat=mode_selection_dist, components=[z0_pos, z0_neg], validate_args=False) return z0_dist def generative(self, unused_observation, z_nm1): """Computes the model's generative distribution p(z_n| z_{n-1}).""" generative_p_mu = z_nm1 + self.bs[-1] return tf.contrib.distributions.Normal( loc=generative_p_mu, scale=tf.sqrt(tf.ones_like(generative_p_mu) * self.variance)) class Model(object): def __init__(self, p, q, r, state_size, num_timesteps, dtype=tf.float32): self.p = p self.q = q self.r = r self.state_size = state_size self.num_timesteps = num_timesteps self.dtype = dtype def zero_state(self, batch_size): return tf.zeros([batch_size, self.state_size], dtype=self.dtype) def __call__(self, prev_state, observation, t): # Compute the q distribution over z, q(z_t|z_n, z_{t-1}). q_zt = self.q.q_zt(observation, prev_state, t) # Compute the p distribution over z, p(z_t|z_{t-1}). p_zt = self.p.p_zt(prev_state, t) # sample from q zt = q_zt.sample() r_xn = self.r.r_xn(zt, t) # Calculate the logprobs and sum over the state size. log_q_zt = tf.reduce_sum(q_zt.log_prob(zt), axis=1) log_p_zt = tf.reduce_sum(p_zt.log_prob(zt), axis=1) log_r_xn = tf.reduce_sum(r_xn.log_prob(observation), axis=1) # If we're at the last timestep, also calc the logprob of the observation. if t == self.num_timesteps - 1: generative_dist = self.p.generative(observation, zt) log_p_x_given_z = tf.reduce_sum(generative_dist.log_prob(observation), axis=1) else: log_p_x_given_z = tf.zeros_like(log_q_zt) return (zt, log_q_zt, log_p_zt, log_p_x_given_z, log_r_xn) @staticmethod def create(state_size, num_timesteps, sigma_min=1e-5, r_sigma_init=1, variance=1.0, mixing_coeff=0.5, prior_mode_mean=1.0, dtype=tf.float32, random_seed=None, train_p=True, p_type="unimodal", q_type="normal", observation_variance=1.0, transition_type=STANDARD_TRANSITION, use_bs=True): if p_type == "unimodal": p = P(state_size, num_timesteps, sigma_min=sigma_min, variance=variance, dtype=dtype, random_seed=random_seed, trainable=train_p, init_bs_to_zero=not use_bs) elif p_type == "bimodal": p = BimodalPriorP( state_size, num_timesteps, mixing_coeff=mixing_coeff, prior_mode_mean=prior_mode_mean, sigma_min=sigma_min, variance=variance, dtype=dtype, random_seed=random_seed, trainable=train_p, init_bs_to_zero=not use_bs) elif "nonlinear" in p_type: if "cauchy" in p_type: trans_dist = tf.contrib.distributions.Cauchy else: trans_dist = tf.contrib.distributions.Normal p = ShortChainNonlinearP( state_size, num_timesteps, sigma_min=sigma_min, variance=variance, observation_variance=observation_variance, transition_type=transition_type, transition_dist=trans_dist, dtype=dtype, random_seed=random_seed ) if q_type == "normal": q_class = Q elif q_type == "simple_mean": q_class = SimpleMeanQ elif q_type == "prev_state": q_class = PreviousStateQ elif q_type == "observation": q_class = ObservationQ q = q_class(state_size, num_timesteps, sigma_min=sigma_min, dtype=dtype, random_seed=random_seed, init_mu0_to_zero=not use_bs) r = R(state_size, num_timesteps, sigma_min=sigma_min, sigma_init=r_sigma_init, dtype=dtype, random_seed=random_seed) model = Model(p, q, r, state_size, num_timesteps, dtype=dtype) return model class BackwardsModel(object): def __init__(self, state_size, num_timesteps, sigma_min=1e-5, dtype=tf.float32): self.state_size = state_size self.num_timesteps = num_timesteps self.sigma_min = sigma_min self.dtype = dtype self.bs = [ tf.get_variable( shape=[state_size], dtype=self.dtype, name="b_%d" % (t + 1), initializer=tf.zeros_initializer) for t in xrange(num_timesteps) ] self.Bs = tf.cumsum(self.bs, reverse=True, axis=0) self.q_mus = [ snt.Linear(output_size=state_size) for _ in xrange(num_timesteps) ] self.q_sigmas = [ tf.get_variable( shape=[state_size], dtype=self.dtype, name="q_sigma_%d" % (t + 1), initializer=tf.zeros_initializer) for t in xrange(num_timesteps) ] self.r_mus = [ tf.get_variable( shape=[state_size], dtype=self.dtype, name="r_mu_%d" % (t + 1), initializer=tf.zeros_initializer) for t in xrange(num_timesteps) ] self.r_sigmas = [ tf.get_variable( shape=[state_size], dtype=self.dtype, name="r_sigma_%d" % (t + 1), initializer=tf.zeros_initializer) for t in xrange(num_timesteps) ] def zero_state(self, batch_size): return tf.zeros([batch_size, self.state_size], dtype=self.dtype) def posterior(self, unused_observation, prev_state, unused_t): # TODO(dieterichl): Correct this. return tf.contrib.distributions.Normal( loc=tf.zeros_like(prev_state), scale=tf.zeros_like(prev_state)) def lookahead(self, state, unused_t): # TODO(dieterichl): Correct this. return tf.contrib.distributions.Normal( loc=tf.zeros_like(state), scale=tf.zeros_like(state)) def q_zt(self, observation, next_state, t): """Computes the variational posterior q(z_{t}|z_{t+1}, z_n).""" t_backwards = self.num_timesteps - t - 1 batch_size = tf.shape(next_state)[0] q_mu = self.q_mus[t_backwards](tf.concat([observation, next_state], axis=1)) q_sigma = tf.maximum( tf.nn.softplus(self.q_sigmas[t_backwards]), self.sigma_min) q_sigma = tf.tile(q_sigma[tf.newaxis, :], [batch_size, 1]) q_zt = tf.contrib.distributions.Normal(loc=q_mu, scale=tf.sqrt(q_sigma)) return q_zt def p_zt(self, prev_state, t): """Computes the model p(z_{t+1}| z_{t}).""" t_backwards = self.num_timesteps - t - 1 z_mu_p = prev_state + self.bs[t_backwards] p_zt = tf.contrib.distributions.Normal( loc=z_mu_p, scale=tf.ones_like(z_mu_p)) return p_zt def generative(self, unused_observation, z_nm1): """Computes the model's generative distribution p(z_n| z_{n-1}).""" generative_p_mu = z_nm1 + self.bs[-1] return tf.contrib.distributions.Normal( loc=generative_p_mu, scale=tf.ones_like(generative_p_mu)) def r(self, z_t, t): t_backwards = self.num_timesteps - t - 1 batch_size = tf.shape(z_t)[0] r_mu = tf.tile(self.r_mus[t_backwards][tf.newaxis, :], [batch_size, 1]) r_sigma = tf.maximum( tf.nn.softplus(self.r_sigmas[t_backwards]), self.sigma_min) r_sigma = tf.tile(r_sigma[tf.newaxis, :], [batch_size, 1]) return tf.contrib.distributions.Normal(loc=r_mu, scale=tf.sqrt(r_sigma)) def likelihood(self, observation): batch_size = tf.shape(observation)[0] mu = tf.tile(tf.reduce_sum(self.bs, axis=0)[tf.newaxis, :], [batch_size, 1]) sigma = tf.ones_like(mu) * (self.num_timesteps + 1) dist = tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma)) # Average over the batch and take the sum over the state size return tf.reduce_mean(tf.reduce_sum(dist.log_prob(observation), axis=1)) def __call__(self, next_state, observation, t): # next state = z_{t+1} # Compute the q distribution over z, q(z_{t}|z_n, z_{t+1}). q_zt = self.q_zt(observation, next_state, t) # sample from q zt = q_zt.sample() # Compute the p distribution over z, p(z_{t+1}|z_{t}). p_zt = self.p_zt(zt, t) # Compute log p(z_{t+1} | z_t) if t == 0: log_p_zt = p_zt.log_prob(observation) else: log_p_zt = p_zt.log_prob(next_state) # Compute r prior over zt r_zt = self.r(zt, t) log_r_zt = r_zt.log_prob(zt) # Compute proposal density at zt log_q_zt = q_zt.log_prob(zt) # If we're at the last timestep, also calc the logprob of the observation. if t == self.num_timesteps - 1: p_z0_dist = tf.contrib.distributions.Normal( loc=tf.zeros_like(zt), scale=tf.ones_like(zt)) z0_log_prob = p_z0_dist.log_prob(zt) else: z0_log_prob = tf.zeros_like(log_q_zt) return (zt, log_q_zt, log_p_zt, z0_log_prob, log_r_zt) class LongChainP(object): def __init__(self, state_size, num_obs, steps_per_obs, sigma_min=1e-5, variance=1.0, observation_variance=1.0, observation_type=STANDARD_OBSERVATION, transition_type=STANDARD_TRANSITION, dtype=tf.float32, random_seed=None): self.state_size = state_size self.steps_per_obs = steps_per_obs self.num_obs = num_obs self.num_timesteps = steps_per_obs*num_obs + 1 self.sigma_min = sigma_min self.dtype = dtype self.variance = variance self.observation_variance = observation_variance self.observation_type = observation_type self.transition_type = transition_type def likelihood(self, observations): """Computes the model's true likelihood of the observations. Args: observations: A [batch_size, m, state_size] Tensor representing each of the m observations. Returns: logprob: The true likelihood of the observations given the model. """ raise ValueError("Likelihood is not defined for long-chain models") # batch_size = tf.shape(observations)[0] # mu = tf.zeros([batch_size, self.state_size, self.num_obs], dtype=self.dtype) # sigma = np.fromfunction( # lambda i, j: 1 + self.steps_per_obs*np.minimum(i+1, j+1), # [self.num_obs, self.num_obs]) # sigma += np.eye(self.num_obs) # sigma = tf.convert_to_tensor(sigma * self.variance, dtype=self.dtype) # sigma = tf.tile(sigma[tf.newaxis, tf.newaxis, ...], # [batch_size, self.state_size, 1, 1]) # dist = tf.contrib.distributions.MultivariateNormalFullCovariance( # loc=mu, # covariance_matrix=sigma) # Average over the batch and take the sum over the state size #return tf.reduce_mean(tf.reduce_sum(dist.log_prob(observations), axis=1)) def p_zt(self, prev_state, t): """Computes the model p(z_t| z_{t-1}).""" batch_size = tf.shape(prev_state)[0] if t > 0: if self.transition_type == ROUND_TRANSITION: loc = tf.round(prev_state) tf.logging.info("p(z_%d | z_%d) ~ N(round(z_%d), %0.1f)" % (t, t-1, t-1, self.variance)) elif self.transition_type == STANDARD_TRANSITION: loc = prev_state tf.logging.info("p(z_%d | z_%d) ~ N(z_%d, %0.1f)" % (t, t-1, t-1, self.variance)) else: # p(z_0) is Normal(0,1) loc = tf.zeros([batch_size, self.state_size], dtype=self.dtype) tf.logging.info("p(z_0) ~ N(0,%0.1f)" % self.variance) p_zt = tf.contrib.distributions.Normal( loc=loc, scale=tf.sqrt(tf.ones_like(loc) * self.variance)) return p_zt def generative(self, z_ni, t): """Computes the model's generative distribution p(x_i| z_{ni}).""" if self.observation_type == SQUARED_OBSERVATION: generative_mu = tf.square(z_ni) tf.logging.info("p(x_%d | z_%d) ~ N(z_%d^2, %0.1f)" % (t, t, t, self.variance)) elif self.observation_type == ABS_OBSERVATION: generative_mu = tf.abs(z_ni) tf.logging.info("p(x_%d | z_%d) ~ N(|z_%d|, %0.1f)" % (t, t, t, self.variance)) elif self.observation_type == STANDARD_OBSERVATION: generative_mu = z_ni tf.logging.info("p(x_%d | z_%d) ~ N(z_%d, %0.1f)" % (t, t, t, self.variance)) generative_sigma_sq = tf.ones_like(generative_mu) * self.observation_variance return tf.contrib.distributions.Normal( loc=generative_mu, scale=tf.sqrt(generative_sigma_sq)) class LongChainQ(object): def __init__(self, state_size, num_obs, steps_per_obs, sigma_min=1e-5, dtype=tf.float32, random_seed=None): self.state_size = state_size self.sigma_min = sigma_min self.dtype = dtype self.steps_per_obs = steps_per_obs self.num_obs = num_obs self.num_timesteps = num_obs*steps_per_obs +1 initializers = { "w": tf.random_uniform_initializer(seed=random_seed), "b": tf.zeros_initializer } self.mus = [ snt.Linear(output_size=state_size, initializers=initializers) for t in xrange(self.num_timesteps) ] self.sigmas = [ tf.get_variable( shape=[state_size], dtype=self.dtype, name="q_sigma_%d" % (t + 1), initializer=tf.random_uniform_initializer(seed=random_seed)) for t in xrange(self.num_timesteps) ] def first_relevant_obs_index(self, t): return int(max((t-1)/self.steps_per_obs, 0)) def q_zt(self, observations, prev_state, t): """Computes a distribution over z_t. Args: observations: a [batch_size, num_observations, state_size] Tensor. prev_state: a [batch_size, state_size] Tensor. t: The current timestep, an int Tensor. """ # filter out unneeded past obs first_relevant_obs_index = int(math.floor(max(t-1, 0) / self.steps_per_obs)) num_relevant_observations = self.num_obs - first_relevant_obs_index observations = observations[:,first_relevant_obs_index:,:] batch_size = tf.shape(prev_state)[0] # concatenate the prev state and observations along the second axis (that is # not the batch or state size axis, and then flatten it to # [batch_size, (num_relevant_observations + 1) * state_size] to feed it into # the linear layer. q_input = tf.concat([observations, prev_state[:,tf.newaxis, :]], axis=1) q_input = tf.reshape(q_input, [batch_size, (num_relevant_observations + 1) * self.state_size]) q_mu = self.mus[t](q_input) q_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min) q_sigma = tf.tile(q_sigma[tf.newaxis, :], [batch_size, 1]) q_zt = tf.contrib.distributions.Normal(loc=q_mu, scale=tf.sqrt(q_sigma)) tf.logging.info( "q(z_{t} | z_{tm1}, x_{obsf}:{obst}) ~ N(Linear([z_{tm1},x_{obsf}:{obst}]), sigma_{t})".format( **{"t": t, "tm1": t-1, "obsf": (first_relevant_obs_index+1)*self.steps_per_obs, "obst":self.steps_per_obs*self.num_obs})) return q_zt def summarize_weights(self): pass class LongChainR(object): def __init__(self, state_size, num_obs, steps_per_obs, sigma_min=1e-5, dtype=tf.float32, random_seed=None): self.state_size = state_size self.dtype = dtype self.sigma_min = sigma_min self.steps_per_obs = steps_per_obs self.num_obs = num_obs self.num_timesteps = num_obs*steps_per_obs + 1 self.sigmas = [ tf.get_variable( shape=[self.num_future_obs(t)], dtype=self.dtype, name="r_sigma_%d" % (t + 1), #initializer=tf.random_uniform_initializer(seed=random_seed, maxval=100)) initializer=tf.constant_initializer(1.0)) for t in range(self.num_timesteps) ] def first_future_obs_index(self, t): return int(math.floor(t / self.steps_per_obs)) def num_future_obs(self, t): return int(self.num_obs - self.first_future_obs_index(t)) def r_xn(self, z_t, t): """Computes a distribution over the future observations given current latent state. The indexing in these messages is 1 indexed and inclusive. This is consistent with the latex documents. Args: z_t: [batch_size, state_size] Tensor t: Current timestep """ tf.logging.info( "r(x_{start}:{end} | z_{t}) ~ N(z_{t}, sigma_{t})".format( **{"t": t, "start": (self.first_future_obs_index(t)+1)*self.steps_per_obs, "end": self.num_timesteps-1})) batch_size = tf.shape(z_t)[0] # the mean for all future observations is the same. # this tiling results in a [batch_size, num_future_obs, state_size] Tensor r_mu = tf.tile(z_t[:,tf.newaxis,:], [1, self.num_future_obs(t), 1]) # compute the variance r_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min) # the variance is the same across all state dimensions, so we only have to # time sigma to be [batch_size, num_future_obs]. r_sigma = tf.tile(r_sigma[tf.newaxis,:, tf.newaxis], [batch_size, 1, self.state_size]) return tf.contrib.distributions.Normal( loc=r_mu, scale=tf.sqrt(r_sigma)) def summarize_weights(self): pass class LongChainModel(object): def __init__(self, p, q, r, state_size, num_obs, steps_per_obs, dtype=tf.float32, disable_r=False): self.p = p self.q = q self.r = r self.disable_r = disable_r self.state_size = state_size self.num_obs = num_obs self.steps_per_obs = steps_per_obs self.num_timesteps = steps_per_obs*num_obs + 1 self.dtype = dtype def zero_state(self, batch_size): return tf.zeros([batch_size, self.state_size], dtype=self.dtype) def next_obs_ind(self, t): return int(math.floor(max(t-1,0)/self.steps_per_obs)) def __call__(self, prev_state, observations, t): """Computes the importance weight for the model system. Args: prev_state: [batch_size, state_size] Tensor observations: [batch_size, num_observations, state_size] Tensor """ # Compute the q distribution over z, q(z_t|z_n, z_{t-1}). q_zt = self.q.q_zt(observations, prev_state, t) # Compute the p distribution over z, p(z_t|z_{t-1}). p_zt = self.p.p_zt(prev_state, t) # sample from q and evaluate the logprobs, summing over the state size zt = q_zt.sample() log_q_zt = tf.reduce_sum(q_zt.log_prob(zt), axis=1) log_p_zt = tf.reduce_sum(p_zt.log_prob(zt), axis=1) if not self.disable_r and t < self.num_timesteps-1: # score the remaining observations using r r_xn = self.r.r_xn(zt, t) log_r_xn = r_xn.log_prob(observations[:, self.next_obs_ind(t+1):, :]) # sum over state size and observation, leaving the batch index log_r_xn = tf.reduce_sum(log_r_xn, axis=[1,2]) else: log_r_xn = tf.zeros_like(log_p_zt) if t != 0 and t % self.steps_per_obs == 0: generative_dist = self.p.generative(zt, t) log_p_x_given_z = generative_dist.log_prob(observations[:,self.next_obs_ind(t),:]) log_p_x_given_z = tf.reduce_sum(log_p_x_given_z, axis=1) else: log_p_x_given_z = tf.zeros_like(log_q_zt) return (zt, log_q_zt, log_p_zt, log_p_x_given_z, log_r_xn) @staticmethod def create(state_size, num_obs, steps_per_obs, sigma_min=1e-5, variance=1.0, observation_variance=1.0, observation_type=STANDARD_OBSERVATION, transition_type=STANDARD_TRANSITION, dtype=tf.float32, random_seed=None, disable_r=False): p = LongChainP( state_size, num_obs, steps_per_obs, sigma_min=sigma_min, variance=variance, observation_variance=observation_variance, observation_type=observation_type, transition_type=transition_type, dtype=dtype, random_seed=random_seed) q = LongChainQ( state_size, num_obs, steps_per_obs, sigma_min=sigma_min, dtype=dtype, random_seed=random_seed) r = LongChainR( state_size, num_obs, steps_per_obs, sigma_min=sigma_min, dtype=dtype, random_seed=random_seed) model = LongChainModel( p, q, r, state_size, num_obs, steps_per_obs, dtype=dtype, disable_r=disable_r) return model class RTilde(object): def __init__(self, state_size, num_timesteps, sigma_min=1e-5, dtype=tf.float32, random_seed=None, graph_collection_name="R_TILDE_VARS"): self.dtype = dtype self.sigma_min = sigma_min initializers = {"w": tf.truncated_normal_initializer(seed=random_seed), "b": tf.zeros_initializer} self.graph_collection_name=graph_collection_name def custom_getter(getter, *args, **kwargs): out = getter(*args, **kwargs) ref = tf.get_collection_ref(self.graph_collection_name) if out not in ref: ref.append(out) return out self.fns = [ snt.Linear(output_size=2*state_size, initializers=initializers, name="r_tilde_%d" % t, custom_getter=custom_getter) for t in xrange(num_timesteps) ] def r_zt(self, z_t, observation, t): #out = self.fns[t](tf.stop_gradient(tf.concat([z_t, observation], axis=1))) out = self.fns[t](tf.concat([z_t, observation], axis=1)) mu, raw_sigma_sq = tf.split(out, 2, axis=1) sigma_sq = tf.maximum(tf.nn.softplus(raw_sigma_sq), self.sigma_min) return mu, sigma_sq class TDModel(object): def __init__(self, p, q, r_tilde, state_size, num_timesteps, dtype=tf.float32, disable_r=False): self.p = p self.q = q self.r_tilde = r_tilde self.disable_r = disable_r self.state_size = state_size self.num_timesteps = num_timesteps self.dtype = dtype def zero_state(self, batch_size): return tf.zeros([batch_size, self.state_size], dtype=self.dtype) def __call__(self, prev_state, observation, t): """Computes the importance weight for the model system. Args: prev_state: [batch_size, state_size] Tensor observations: [batch_size, num_observations, state_size] Tensor """ # Compute the q distribution over z, q(z_t|z_n, z_{t-1}). q_zt = self.q.q_zt(observation, prev_state, t) # Compute the p distribution over z, p(z_t|z_{t-1}). p_zt = self.p.p_zt(prev_state, t) # sample from q and evaluate the logprobs, summing over the state size zt = q_zt.sample() # If it isn't the last timestep, compute the distribution over the next z. if t < self.num_timesteps - 1: p_ztplus1 = self.p.p_zt(zt, t+1) else: p_ztplus1 = None log_q_zt = tf.reduce_sum(q_zt.log_prob(zt), axis=1) log_p_zt = tf.reduce_sum(p_zt.log_prob(zt), axis=1) if not self.disable_r and t < self.num_timesteps-1: # score the remaining observations using r r_tilde_mu, r_tilde_sigma_sq = self.r_tilde.r_zt(zt, observation, t+1) else: r_tilde_mu = None r_tilde_sigma_sq = None if t == self.num_timesteps - 1: generative_dist = self.p.generative(observation, zt) log_p_x_given_z = tf.reduce_sum(generative_dist.log_prob(observation), axis=1) else: log_p_x_given_z = tf.zeros_like(log_q_zt) return (zt, log_q_zt, log_p_zt, log_p_x_given_z, r_tilde_mu, r_tilde_sigma_sq, p_ztplus1) @staticmethod def create(state_size, num_timesteps, sigma_min=1e-5, variance=1.0, dtype=tf.float32, random_seed=None, train_p=True, p_type="unimodal", q_type="normal", mixing_coeff=0.5, prior_mode_mean=1.0, observation_variance=1.0, transition_type=STANDARD_TRANSITION, use_bs=True): if p_type == "unimodal": p = P(state_size, num_timesteps, sigma_min=sigma_min, variance=variance, dtype=dtype, random_seed=random_seed, trainable=train_p, init_bs_to_zero=not use_bs) elif p_type == "bimodal": p = BimodalPriorP( state_size, num_timesteps, mixing_coeff=mixing_coeff, prior_mode_mean=prior_mode_mean, sigma_min=sigma_min, variance=variance, dtype=dtype, random_seed=random_seed, trainable=train_p, init_bs_to_zero=not use_bs) elif "nonlinear" in p_type: if "cauchy" in p_type: trans_dist = tf.contrib.distributions.Cauchy else: trans_dist = tf.contrib.distributions.Normal p = ShortChainNonlinearP( state_size, num_timesteps, sigma_min=sigma_min, variance=variance, observation_variance=observation_variance, transition_type=transition_type, transition_dist=trans_dist, dtype=dtype, random_seed=random_seed ) if q_type == "normal": q_class = Q elif q_type == "simple_mean": q_class = SimpleMeanQ elif q_type == "prev_state": q_class = PreviousStateQ elif q_type == "observation": q_class = ObservationQ q = q_class(state_size, num_timesteps, sigma_min=sigma_min, dtype=dtype, random_seed=random_seed, init_mu0_to_zero=not use_bs) r_tilde = RTilde( state_size, num_timesteps, sigma_min=sigma_min, dtype=dtype, random_seed=random_seed) model = TDModel(p, q, r_tilde, state_size, num_timesteps, dtype=dtype) return model