"""Packed Sequence Op.""" |
|
|
|
|
|
|
|
|
|
|
|
from typing import Dict, Optional, List, Union |
|
|
|
import tensorflow as tf |
|
|
|
AUTOTUNE = tf.data.experimental.AUTOTUNE |
|
|
|
|
|
def pack_dataset(dataset: tf.data.Dataset,
                 key2length: Union[int, Dict[str, int]],
                 keys: Optional[List[str]] = None) -> tf.data.Dataset:
  """Creates a 'packed' version of a dataset on-the-fly.

  Adapted from the mesh-tf implementation.

  This is meant to replace the irritation of having to create a separate
  "packed" version of a dataset to train efficiently on TPU.

  Each example in the output dataset represents several examples in the
  input dataset. For each key in the input dataset, two additional keys are
  created:
    <key>_seg: an int32 tensor identifying the parts
      representing the original example.
    <key>_pos: an int32 tensor identifying the position within the original
      example.

  Example:
    Two input examples get combined to form an output example.
    The input examples are:
      {"inputs": [8, 7, 1, 0], "targets": [4, 1, 0]}
      {"inputs": [2, 3, 4, 1], "targets": [5, 6, 1]}
    The output example is:
      {
          "inputs": [8, 7, 1, 2, 3, 4, 1, 0, 0, 0],
          "inputs_seg": [1, 1, 1, 2, 2, 2, 2, 0, 0, 0],
          "inputs_pos": [0, 1, 2, 0, 1, 2, 3, 0, 0, 0],
          "targets": [4, 1, 5, 6, 1, 0, 0, 0, 0, 0],
          "targets_seg": [1, 1, 2, 2, 2, 0, 0, 0, 0, 0],
          "targets_pos": [0, 1, 0, 1, 2, 0, 0, 0, 0, 0],
      }
    0 represents padding in both the inputs and the outputs.

  Sequences in the incoming examples are truncated to the length given by
  `key2length`, and the sequences in the output examples all have this fixed
  (padded) length.

  Args:
    dataset: a tf.data.Dataset
    key2length: an integer, or a dict from feature-key to integer
    keys: a list of strings (e.g. ["inputs", "targets"])

  Returns:
    a tf.data.Dataset
  """
|
  shapes = tf.nest.map_structure(lambda spec: spec.shape, dataset.element_spec)
  if keys is None:
    keys = list(shapes.keys())
  for k in keys:
    if k not in shapes:
      raise ValueError(f'Key {k} not found in dataset. Available keys are '
                       f'{list(shapes.keys())}')
    if not shapes[k].is_compatible_with(tf.TensorShape([None])):
      raise ValueError('Tensors to be packed must be one-dimensional.')

  # Normalize key2length to a dict and add entries for the generated
  # "_seg"/"_pos" features, which share the length of their parent key.
  if isinstance(key2length, int):
    key2length = {k: key2length for k in keys}
  else:
    key2length = dict(key2length)
  for k in keys:
    for suffix in ['_seg', '_pos']:
      key2length[k + suffix] = key2length[k]

  # Trim each sequence to its maximum allowed length.
  dataset = dataset.map(
      lambda x: {k: x[k][:key2length[k]] for k in keys},
      num_parallel_calls=AUTOTUNE)

  # Setting batch_size=length ensures that the concatenated sequences (if they
  # have length >= 1) are sufficient to fill at least one packed example.
  batch_size = max(key2length.values())
  dataset = dataset.padded_batch(
      batch_size, padded_shapes={k: [-1] for k in keys})
  dataset = _pack_with_tf_ops(dataset, keys, key2length)

  # Set the Tensor shapes correctly since they get lost in the process.
  def my_fn(x):
    return {k: tf.reshape(v, [key2length[k]]) for k, v in x.items()}

  return dataset.map(my_fn, num_parallel_calls=AUTOTUNE)
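
# A minimal usage sketch, not part of the original module: it builds the toy
# two-example dataset from the docstring above and packs it to length 10. The
# helper name `_example_usage` is illustrative only.
def _example_usage() -> tf.data.Dataset:
  """Returns the docstring's two toy examples packed to length 10."""

  def gen():
    yield {'inputs': [8, 7, 1], 'targets': [4, 1]}
    yield {'inputs': [2, 3, 4, 1], 'targets': [5, 6, 1]}

  ds = tf.data.Dataset.from_generator(
      gen,
      output_signature={
          'inputs': tf.TensorSpec([None], tf.int32),
          'targets': tf.TensorSpec([None], tf.int32),
      })
  # Each element of the result carries "inputs", "inputs_seg", "inputs_pos",
  # "targets", "targets_seg" and "targets_pos", all of fixed length 10.
  return pack_dataset(ds, key2length=10)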
|
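# For reference, a hedged pure-Python sketch (not in the original module;
# helper name illustrative) of the greedy algorithm that _pack_with_tf_ops()
# below expresses with TF ops: examples are appended to the current partial
# pack until some addition would overflow the target length, at which point
# the pack is flushed, zero-padded, and a new one begun.
def _greedy_pack_reference(seqs: List[List[int]],
                           length: int) -> List[List[int]]:
  """Packs variable-length int lists into zero-padded lists of `length`."""
  packs, partial = [], []
  for seq in seqs:
    seq = seq[:length]  # trim, mirroring the dataset.map() trim above
    if len(partial) + len(seq) > length:
      packs.append(partial + [0] * (length - len(partial)))
      partial = []
    partial = partial + seq
  # Flush the final partial pack, mirroring the last write_packed_example().
  packs.append(partial + [0] * (length - len(partial)))
  return packs
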
def _pack_with_tf_ops(dataset: tf.data.Dataset, keys: List[str],
                      key2length: Dict[str, int]) -> tf.data.Dataset:
  """Helper function for packing a dataset which has already been batched.

  Helper for pack_dataset(). Uses tf.while_loop.

  Args:
    dataset: a dataset containing padded batches of examples.
    keys: a list of strings
    key2length: a dict from feature-key to integer

  Returns:
    a dataset.
  """
  empty_example = {}
  for k in keys:
    empty_example[k] = tf.zeros([0], dtype=tf.int32)
    empty_example[k + '_pos'] = tf.zeros([0], dtype=tf.int32)
  keys_etc = empty_example.keys()

  def write_packed_example(partial, outputs):
    """Pads `partial` to full length, appends it to `outputs`, and resets."""
    new_partial = empty_example.copy()
    new_outputs = {}
    for k in keys_etc:
      new_outputs[k] = outputs[k].write(
          outputs[k].size(),
          tf.pad(partial[k], [[0, key2length[k] - tf.size(partial[k])]]))
    return new_partial, new_outputs

  def map_fn(x):
    """Internal function to map over.

    Consumes a batch of input examples and produces a variable number of
    output examples.

    Args:
      x: a batch of padded examples

    Returns:
      a dict of Tensors holding the packed examples, to be unbatched later.
    """
    partial = empty_example.copy()
    i = tf.zeros([], dtype=tf.int32)
    dynamic_batch_size = tf.shape(x[keys[0]])[0]
    outputs = {}
    for k in keys:
      outputs[k] = tf.TensorArray(
          tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]])
      outputs[k + '_pos'] = tf.TensorArray(
          tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]])

    def body_fn(i, partial, outputs):
      """Body function for while_loop.

      Args:
        i: integer scalar
        partial: dictionary of Tensor (partially-constructed example)
        outputs: dictionary of TensorArray

      Returns:
        A triple containing the new values of the inputs.
      """
      can_append = True
      one_example = {}
      for k in keys:
        # Strip the trailing padding from the i-th example. This assumes 0 is
        # used only as padding, never as a real token id.
        val = tf.cast(x[k][i], tf.int32)
        val = val[:tf.reduce_sum(tf.cast(tf.not_equal(val, 0), tf.int32))]
        one_example[k] = val
      # The example can be appended only if every key still fits.
      for k in keys:
        can_append = tf.logical_and(
            can_append,
            tf.less_equal(
                tf.size(partial[k]) + tf.size(one_example[k]), key2length[k]))

      def false_fn():
        # No room left: flush the partial example and start a fresh one.
        return write_packed_example(partial, outputs)

      def true_fn():
        return partial, outputs

      partial, outputs = tf.cond(can_append, true_fn, false_fn)
      # Append the current example to the (possibly fresh) partial example.
      new_partial = {}
      for k in keys:
        new_seq = one_example[k][:key2length[k]]
        new_seq_len = tf.size(new_seq)
        new_partial[k] = tf.concat([partial[k], new_seq], 0)
        new_partial[k + '_pos'] = tf.concat(
            [partial[k + '_pos'], tf.range(new_seq_len)], 0)
      partial = new_partial
      return i + 1, partial, outputs

    # Iterate over every example in the batch; maximum_iterations supplies the
    # termination condition since cond is always true.
    i, partial, outputs = tf.while_loop(
        cond=lambda *_: True,
        body=body_fn,
        loop_vars=(i, partial, outputs),
        shape_invariants=(
            tf.TensorShape([]),
            {k: tf.TensorShape([None]) for k in keys_etc},
            {k: tf.TensorShape(None) for k in keys_etc},
        ),
        maximum_iterations=dynamic_batch_size)
    # Flush the final partially-packed example.
    _, outputs = write_packed_example(partial, outputs)
    packed = {k: outputs[k].stack() for k in keys_etc}
    for k in keys:
      # The segment id increments wherever a new original example starts
      # (position 0) and is zeroed out on padding.
      packed[k + '_seg'] = (
          tf.cumsum(
              tf.cast(tf.equal(packed[k + '_pos'], 0), tf.int32), axis=1) *
          tf.cast(tf.not_equal(packed[k], 0), tf.int32))
    return packed

  dataset = dataset.map(map_fn, num_parallel_calls=AUTOTUNE)
  return dataset.unbatch()
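
# A possible downstream use, not part of the original module: the "<key>_seg"
# feature lets a Transformer restrict self-attention to tokens from the same
# original example. A hedged sketch (the helper name is illustrative):
def _make_attention_mask(seg: tf.Tensor) -> tf.Tensor:
  """Returns a [length, length] bool mask, True where attention is allowed.

  Args:
    seg: an int32 Tensor of shape [length] of segment ids, 0 for padding.
  """
  same_segment = tf.equal(seg[:, None], seg[None, :])
  not_padding = tf.not_equal(seg, 0)
  return tf.logical_and(
      same_segment,
      tf.logical_and(not_padding[:, None], not_padding[None, :]))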