"""Common utilities used in evaluators.""" |
|
import math |
|
import jax |
|
import tensorflow as tf |
|
import tensorflow_datasets as tfds |
|
|
|
|
|
def get_jax_process_dataset(dataset, split, global_batch_size, pp_fn,
                            dataset_dir=None, cache=True, add_tfds_id=False):
  """Returns the dataset to be processed by the current jax host.

  The dataset is sharded across processes and padded with zero-valued
  examples such that all processes see an equal number of batches. The first
  two dimensions of the dataset elements are:
  [local_device_count, device_batch_size].

  Args:
    dataset: dataset name.
    split: dataset split.
    global_batch_size: number of examples to be processed per iteration on the
      dataset, summed across all devices.
    pp_fn: preprocessing function to apply per example.
    dataset_dir: path for tfds to find the prepared data.
    cache: whether to cache the dataset after batching.
    add_tfds_id: whether to add the unique `tfds_id` string to each example.

  Returns:
    A `tf.data.Dataset` with this process's shard of `split`.
  """
  assert global_batch_size % jax.device_count() == 0

  # Number of batches needed to cover the full split once at the global
  # batch size.
  total_examples = tfds.load(
      dataset, split=split, data_dir=dataset_dir).cardinality()
  num_batches = math.ceil(total_examples / global_batch_size)

  # Each process reads its own (nearly) even share of the split.
  process_split = tfds.even_splits(
      split, n=jax.process_count(), drop_remainder=False)[jax.process_index()]
  data = tfds.load(
      dataset,
      split=process_split,
      data_dir=dataset_dir,
      read_config=tfds.ReadConfig(add_tfds_id=add_tfds_id)).map(pp_fn)

  # An infinite stream of all-zero examples with the same structure as the
  # real data, used to pad every process up to `num_batches` batches.
  pad_data = tf.data.Dataset.from_tensors(
      jax.tree_util.tree_map(
          lambda x: tf.zeros(x.shape, x.dtype), data.element_spec)
  ).repeat()

  data = data.concatenate(pad_data)
  # Yields elements shaped [local_device_count, device_batch_size, ...].
  data = data.batch(global_batch_size // jax.device_count())
  data = data.batch(jax.local_device_count())
  data = data.take(num_batches)
  if cache:
    # Eval datasets are typically iterated over many times; caching the final
    # batched dataset avoids redoing file reads and preprocessing on every
    # pass.
    data = data.cache()
  return data
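

# Example usage (a minimal sketch; the dataset name "cifar10" and this
# `pp_fn` are illustrative, not part of this module):
#
#   def pp_fn(ex):
#     ex["image"] = tf.cast(ex["image"], tf.float32) / 255.0
#     return ex
#
#   ds = get_jax_process_dataset("cifar10", "test", global_batch_size=1024,
#                                pp_fn=pp_fn, add_tfds_id=True)
#   for batch in ds.as_numpy_iterator():
#     # batch["image"].shape == (local_device_count, device_batch_size,
#     #                          32, 32, 3)
#     # Padding examples are all zeros; with add_tfds_id=True they can be
#     # spotted by their empty `tfds_id` string.
#     ...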