|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""TensorFlow Datasets as data source for big_vision.""" |
|
import functools |
|
|
|
import big_vision.datasets.core as ds_core |
|
import jax |
|
import numpy as np |
|
import overrides |
|
import tensorflow as tf |
|
import tensorflow_datasets as tfds |
|
|
|
|
|
class DataSource(ds_core.DataSource):
  """Exposes a TFDS builder as a big_vision data source."""

  def __init__(self, name, split, data_dir=None, skip_decode=("image",)):
    """Initializes the source.

    Args:
      name: TFDS dataset name, or the sentinel "from_data_dir" to load a
        builder directly from `data_dir`.
      split: TFDS split spec (e.g. "train", "train[:90%]").
      data_dir: Optional TFDS data directory.
      skip_decode: Feature names whose decoding is skipped (only those
        actually present in the builder's features are skipped).
    """
    self.builder = _get_builder(name, data_dir)
    self.split = split
    self.skip_decode = skip_decode

    # Each JAX process reads its own disjoint sub-split of `split`.
    per_process = tfds.even_splits(split, jax.process_count())
    self.process_split = per_process[jax.process_index()]

  @overrides.overrides
  def get_tfdata(
      self, ordered=False, *, process_split=True, allow_cache=True, **kw):
    """Returns the tf.data.Dataset for this source.

    Args:
      ordered: If True, reads files in deterministic order (no file shuffle).
      process_split: If True, reads only this process' sub-split; otherwise
        the full split.
      allow_cache: If True, reuses a cached dataset for identical arguments.
      **kw: Forwarded to the dataset-construction helper.
    """
    make_ds = _cached_get_dataset if allow_cache else _get_dataset
    which = self.process_split if process_split else self.split
    return make_ds(
        self.builder, self.skip_decode,
        split=which,
        shuffle_files=not ordered,
        **kw)

  @property
  @overrides.overrides
  def total_examples(self):
    """Number of examples in the full (un-sharded) split."""
    return self.builder.info.splits[self.split].num_examples

  @overrides.overrides
  def num_examples_per_process(self):
    """Returns a list with each process' example count, indexed by process."""
    counts = []
    for sub_split in tfds.even_splits(self.split, jax.process_count()):
      counts.append(self.builder.info.splits[sub_split].num_examples)
    return counts
|
|
|
|
|
@functools.cache
def _get_builder(dataset, data_dir):
  """Returns a (process-wide cached) TFDS builder.

  The special name "from_data_dir" loads whatever builder lives in
  `data_dir`; any other name goes through the regular TFDS registry.
  """
  if dataset == "from_data_dir":
    return tfds.builder_from_directory(data_dir)
  return tfds.builder(dataset, data_dir=data_dir, try_gcs=True)
|
|
|
|
|
|
|
|
|
def _get_dataset(builder, skip_decode, **kw):
  """Builds the tf.data.Dataset and adds a stable integer id per example."""
  # Pull out kwargs that belong to ReadConfig; everything else is passed
  # straight through to `as_dataset`.
  read_kw = {}
  if "shuffle_seed" in kw:
    read_kw["shuffle_seed"] = kw.pop("shuffle_seed")

  # Only skip decoding for features the builder actually has.
  decoders = {
      feature: tfds.decode.SkipDecoding()
      for feature in skip_decode
      if feature in builder.info.features
  }

  ds = builder.as_dataset(
      read_config=tfds.ReadConfig(
          # NOTE(review): prefetch/autocache are disabled here, presumably
          # handled later in the input pipeline — confirm with callers.
          skip_prefetch=True,
          try_autocache=False,
          add_tfds_id=True,
          **read_kw,
      ),
      decoders=decoders,
      **kw)

  def _with_int_id(example):
    # Derive a deterministic id from the unique "tfds_id" string; the
    # int64 hash is bitcast to two int32s and the first is kept.
    hashed = tf.strings.to_hash_bucket_strong(
        example["tfds_id"],
        np.iinfo(np.uint32).max,
        [3714561454027272724, 8800639020734831960])
    example["_id"] = tf.bitcast(hashed, tf.int32)[0]
    return example

  return ds.map(_with_int_id)


_cached_get_dataset = functools.cache(_get_dataset)
|
|