"""Datasets.""" |

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

import models
|
|
def make_long_chain_dataset(
    state_size=1,
    num_obs=5,
    steps_per_obs=3,
    variance=1.,
    observation_variance=1.,
    batch_size=4,
    num_samples=1,
    observation_type=models.STANDARD_OBSERVATION,
    transition_type=models.STANDARD_TRANSITION,
    fixed_observation=None,
    dtype="float32"):
  """Creates a long chain data generating process.

  Creates a tf.data.Dataset that provides batches of data from a long chain.

  Args:
    state_size: The dimension of the state space of the process.
    num_obs: The number of observations in the chain.
    steps_per_obs: The number of steps between each observation.
    variance: The variance of the normal distributions used at each timestep.
    observation_variance: The variance of the observation noise.
    batch_size: The number of trajectories to include in each batch.
    num_samples: The number of replicas of each trajectory to include in each
      batch.
    observation_type: The observation emission to use, one of
      models.STANDARD_OBSERVATION, models.SQUARED_OBSERVATION, or
      models.ABS_OBSERVATION.
    transition_type: The state transition to use, either
      models.STANDARD_TRANSITION or models.ROUND_TRANSITION.
    fixed_observation: If provided, all observations are set to this constant
      value instead of being sampled.
    dtype: The datatype of the states and observations.
  Returns:
    dataset: A tf.data.Dataset that can be iterated over.
  """
  num_timesteps = num_obs * steps_per_obs

  def data_generator():
    """An infinite generator of latents and observations from the model."""
    while True:
      states = []
      observations = []

      # Sample the initial state from a zero-mean normal distribution.
      states.append(
          np.random.normal(size=[state_size],
                           scale=np.sqrt(variance)).astype(dtype))

      for t in xrange(1, num_timesteps+1):
        if transition_type == models.ROUND_TRANSITION:
          loc = np.round(states[-1])
        elif transition_type == models.STANDARD_TRANSITION:
          loc = states[-1]
        else:
          raise ValueError("Unsupported transition_type: %s" % transition_type)
        new_state = np.random.normal(size=[state_size],
                                     loc=loc,
                                     scale=np.sqrt(variance))
        states.append(new_state.astype(dtype))
        # Emit an observation every steps_per_obs steps.
        if t % steps_per_obs == 0:
          if fixed_observation is None:
            if observation_type == models.SQUARED_OBSERVATION:
              loc = np.square(states[-1])
            elif observation_type == models.ABS_OBSERVATION:
              loc = np.abs(states[-1])
            elif observation_type == models.STANDARD_OBSERVATION:
              loc = states[-1]
            else:
              raise ValueError(
                  "Unsupported observation_type: %s" % observation_type)
            new_obs = np.random.normal(
                size=[state_size],
                loc=loc,
                scale=np.sqrt(observation_variance)).astype(dtype)
          else:
            # Cast so fixed observations match the declared output dtype.
            new_obs = (np.ones([state_size]) * fixed_observation).astype(dtype)
          observations.append(new_obs)
      yield states, observations

  dataset = tf.data.Dataset.from_generator(
      data_generator,
      output_types=(tf.as_dtype(dtype), tf.as_dtype(dtype)),
      output_shapes=([num_timesteps+1, state_size], [num_obs, state_size]))
  dataset = dataset.repeat().batch(batch_size)

  def tile_batch(state, observation):
    # Replicate each trajectory num_samples times along the batch dimension.
    state = tf.tile(state, [num_samples, 1, 1])
    observation = tf.tile(observation, [num_samples, 1, 1])
    return state, observation

  dataset = dataset.map(tile_batch, num_parallel_calls=12).prefetch(1024)
  return dataset
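# A minimal usage sketch (illustrative, not part of the original module):
# draws one batch from the long chain process with a TF1-style one-shot
# iterator. With the defaults above, num_timesteps = num_obs * steps_per_obs
# = 15, and the shapes follow from the output_shapes declared above. Assumes
# a TF1 runtime where tf.data is consumed through a tf.Session.
#
#   dataset = make_long_chain_dataset(batch_size=2, num_samples=3)
#   states, observations = dataset.make_one_shot_iterator().get_next()
#   # states: [batch_size * num_samples, num_timesteps + 1, state_size]
#   # observations: [batch_size * num_samples, num_obs, state_size]
#   with tf.Session() as sess:
#     states_np, observations_np = sess.run([states, observations])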
|
|
def make_dataset(bs=None,
                 state_size=1,
                 num_timesteps=10,
                 variance=1.,
                 prior_type="unimodal",
                 bimodal_prior_weight=0.5,
                 bimodal_prior_mean=1,
                 transition_type=models.STANDARD_TRANSITION,
                 fixed_observation=None,
                 batch_size=4,
                 num_samples=1,
                 dtype="float32"):
  """Creates a data generating process.

  Creates a tf.data.Dataset that provides batches of data.

  Args:
    bs: The parameters of the data generating process. If None, new bs are
      randomly generated.
    state_size: The dimension of the state space of the process.
    num_timesteps: The length of the state sequences in the process.
    variance: The variance of the normal distributions used at each timestep.
    prior_type: The prior over the initial state, one of "unimodal",
      "nonlinear", or "bimodal". "unimodal" and "nonlinear" both use a
      zero-mean Gaussian prior.
    bimodal_prior_weight: The probability that the initial state is drawn
      from the mode at -bimodal_prior_mean when prior_type is "bimodal".
    bimodal_prior_mean: The location of the two modes, +/- bimodal_prior_mean,
      when prior_type is "bimodal".
    transition_type: The state transition to use, either
      models.STANDARD_TRANSITION or models.ROUND_TRANSITION.
    fixed_observation: If provided, the observation is set to this constant
      value instead of being sampled from the final state.
    batch_size: The number of trajectories to include in each batch.
    num_samples: The number of replicas of each trajectory to include in each
      batch.
    dtype: The datatype of the states and observations.
  Returns:
    bs: The true bs used to generate the data.
    dataset: A tf.data.Dataset that can be iterated over.
  """
  if bs is None:
    bs = [np.random.uniform(size=[state_size]).astype(dtype)
          for _ in xrange(num_timesteps)]
    tf.logging.info("data generating process bs: %s",
                    np.array(bs).reshape(num_timesteps))
|
  def data_generator():
    """An infinite generator of latents and observations from the model."""
    while True:
      states = []
      # Sample the initial state. "unimodal" and "nonlinear" both use a
      # zero-mean Gaussian prior; "bimodal" mixes two symmetric modes.
      if prior_type in ("unimodal", "nonlinear"):
        states.append(np.random.normal(size=[state_size],
                                       scale=np.sqrt(variance)).astype(dtype))
      elif prior_type == "bimodal":
        if np.random.uniform() > bimodal_prior_weight:
          loc = bimodal_prior_mean
        else:
          loc = -bimodal_prior_mean
        states.append(np.random.normal(size=[state_size],
                                       loc=loc,
                                       scale=np.sqrt(variance)).astype(dtype))
      else:
        raise ValueError("Unsupported prior_type: %s" % prior_type)

      for t in xrange(num_timesteps):
        if transition_type == models.ROUND_TRANSITION:
          loc = np.round(states[-1])
        elif transition_type == models.STANDARD_TRANSITION:
          loc = states[-1]
        else:
          raise ValueError("Unsupported transition_type: %s" % transition_type)
        # Avoid += here: loc may alias states[-1], and an in-place add
        # would corrupt the stored latent state.
        loc = loc + bs[t]
        new_state = np.random.normal(size=[state_size],
                                     loc=loc,
                                     scale=np.sqrt(variance)).astype(dtype)
        states.append(new_state)

      # Only the final state is observed.
      if fixed_observation is None:
        observation = states[-1]
      else:
        observation = np.ones_like(states[-1]) * fixed_observation
      yield np.array(states[:-1]), observation

  dataset = tf.data.Dataset.from_generator(
      data_generator,
      output_types=(tf.as_dtype(dtype), tf.as_dtype(dtype)),
      output_shapes=([num_timesteps, state_size], [state_size]))
  dataset = dataset.repeat().batch(batch_size)

  def tile_batch(state, observation):
    # Replicate each trajectory num_samples times along the batch dimension.
    state = tf.tile(state, [num_samples, 1, 1])
    observation = tf.tile(observation, [num_samples, 1])
    return state, observation

  dataset = dataset.map(tile_batch, num_parallel_calls=12).prefetch(1024)
  return np.array(bs), dataset
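# A minimal usage sketch (illustrative, not part of the original module):
# builds the process, recovers the true bs, and fetches one tiled batch.
# Assumes a TF1 runtime where the dataset is consumed through a tf.Session.
#
#   true_bs, dataset = make_dataset(num_timesteps=10, batch_size=2,
#                                   num_samples=4)
#   states, observation = dataset.make_one_shot_iterator().get_next()
#   # states: [batch_size * num_samples, num_timesteps, state_size]
#   # observation: [batch_size * num_samples, state_size]
#   with tf.Session() as sess:
#     states_np, observation_np = sess.run([states, observation])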