pranavSIT's picture
added pali inference
74e8f2f
# Copyright 2022 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common utilities used in evaluators."""
import math
import jax
import tensorflow as tf
import tensorflow_datasets as tfds
def get_jax_process_dataset(dataset, split, global_batch_size, pp_fn,
dataset_dir=None, cache=True, add_tfds_id=False):
"""Returns dataset to be processed by current jax host.
The dataset is sharded and padded with zeros such that all processes
have equal number of batches. The first 2 dimensions of the dataset
elements are: [local_device_count, device_batch_size].
Args:
dataset: dataset name.
split: dataset split.
global_batch_size: batch size to be process per iteration on the dataset.
pp_fn: preprocessing function to apply per example.
dataset_dir: path for tfds to find the prepared data.
cache: whether to cache the dataset after batching.
add_tfds_id: whether to add the unique `tfds_id` string to each example.
"""
assert global_batch_size % jax.device_count() == 0
total_examples = tfds.load(
dataset, split=split, data_dir=dataset_dir).cardinality()
num_batches = math.ceil(total_examples / global_batch_size)
process_split = tfds.even_splits(
split, n=jax.process_count(), drop_remainder=False)[jax.process_index()]
data = tfds.load(
dataset,
split=process_split,
data_dir=dataset_dir,
read_config=tfds.ReadConfig(add_tfds_id=add_tfds_id)).map(pp_fn)
pad_data = tf.data.Dataset.from_tensors(
jax.tree_map(lambda x: tf.zeros(x.shape, x.dtype), data.element_spec)
).repeat()
data = data.concatenate(pad_data)
data = data.batch(global_batch_size // jax.device_count())
data = data.batch(jax.local_device_count())
data = data.take(num_batches)
if cache:
# Eval datasets are often used many times and caching the dataset after
# batching allows one to have the buffers ready to be used and not have
# to wait for preprocessing to be done over and over.
data = data.cache()
return data