Pradeep Kumar committed on
Delete pretrain_dynamic_dataloader.py

pretrain_dynamic_dataloader.py  +0 -223

pretrain_dynamic_dataloader.py
DELETED
@@ -1,223 +0,0 @@
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Dataset loader for pre-training with dynamic sequence length."""

import dataclasses
from typing import Optional, Tuple

import tensorflow as tf
import tf_keras  # pylint: disable=unused-import

from official.core import config_definitions as cfg
from official.core import input_reader
from official.nlp.data import data_loader_factory
from official.nlp.data import pretrain_dataloader


@dataclasses.dataclass
class BertPretrainDataConfig(cfg.DataConfig):
  """Data config for BERT pretraining task (tasks/masked_lm)."""
  input_path: str = ''
  global_batch_size: int = 512
  is_training: bool = True
  seq_bucket_lengths: Tuple[int, ...] = (128, 256, 384, 512)
  # TODO(rxsang): `seq_bucket_window_scale` is only useful when round robin
  # tf.data service is disabled. Deprecate this flag once we always enable
  # round robin tf.data service.
  seq_bucket_window_scale: int = 8
  use_next_sentence_label: bool = True
  use_position_id: bool = False
  deterministic: bool = False
  enable_tf_data_service: bool = False
  enable_round_robin_tf_data_service: bool = False
  tf_data_service_job_name: str = 'bert_pretrain'
  use_v2_feature_names: bool = False
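
# A hypothetical instantiation of the config above (illustrative only: the
# input path is a placeholder glob and the two bucket lengths are arbitrary):
#
#   config = BertPretrainDataConfig(
#       input_path='/path/to/pretrain*.tfrecord',
#       global_batch_size=512,
#       seq_bucket_lengths=(128, 512))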


@data_loader_factory.register_data_loader_cls(BertPretrainDataConfig)
class PretrainingDynamicDataLoader(pretrain_dataloader.BertPretrainDataLoader):
  """Dataset loader for BERT-style pretraining with dynamic sequence length.

  Bucketizes the input id features by `seq_bucket_lengths` and pads features
  to the bucket boundaries. The mask features are usually shorter than the
  input id features and can also be dynamic, but the mask feature lengths
  within a bucket must all be the same. For example, with [128, 256] buckets,
  the mask features for bucket 128 must all have some length X and the mask
  features for bucket 256 must all have some length Y.

  The dataloader does not filter out empty masks. Make sure to handle this
  in the model.
  """

  def __init__(self, params):
    self._params = params
    if len(params.seq_bucket_lengths) < 1:
      raise ValueError('The seq_bucket_lengths cannot be empty.')
    self._seq_bucket_lengths = params.seq_bucket_lengths
    self._seq_bucket_window_scale = params.seq_bucket_window_scale
    self._global_batch_size = params.global_batch_size
    self._use_next_sentence_label = params.use_next_sentence_label
    self._use_position_id = params.use_position_id
    self._drop_remainder = params.drop_remainder
    self._enable_tf_data_service = params.enable_tf_data_service
    self._enable_round_robin_tf_data_service = (
        params.enable_round_robin_tf_data_service)
    self._mask_keys = [
        'masked_lm_positions', 'masked_lm_ids', 'masked_lm_weights'
    ]

  def _decode(self, record: tf.Tensor):
    """Decodes a serialized tf.Example."""
    name_to_features = {
        'input_mask': tf.io.VarLenFeature(tf.int64),
        'masked_lm_positions': tf.io.VarLenFeature(tf.int64),
        'masked_lm_ids': tf.io.VarLenFeature(tf.int64),
        'masked_lm_weights': tf.io.VarLenFeature(tf.float32),
    }
    if self._params.use_v2_feature_names:
      input_ids_key = 'input_word_ids'
      segment_key = 'input_type_ids'
    else:
      input_ids_key = 'input_ids'
      segment_key = 'segment_ids'
    name_to_features.update({
        input_ids_key: tf.io.VarLenFeature(tf.int64),
        segment_key: tf.io.VarLenFeature(tf.int64),
    })
    if self._use_next_sentence_label:
      name_to_features['next_sentence_labels'] = tf.io.FixedLenFeature(
          [1], tf.int64)
    dynamic_keys = [input_ids_key, 'input_mask', segment_key]
    if self._use_position_id:
      name_to_features['position_ids'] = tf.io.VarLenFeature(tf.int64)
      dynamic_keys.append('position_ids')

    example = tf.io.parse_single_example(record, name_to_features)
    for key in dynamic_keys + self._mask_keys:
      example[key] = tf.sparse.to_dense(example[key])

    # Truncate padding after the last non-pad token along the sequence
    # dimension. Padding that occurs before the last non-pad token must be
    # kept, so a reverse cumulative sum is used: it is positive exactly up
    # to and including the last non-zero id.
    mask = tf.math.greater(
        tf.math.cumsum(example[input_ids_key], reverse=True), 0)
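    # Worked example (hypothetical ids): for input ids [5, 0, 7, 0, 0], the
    # reverse cumsum is [12, 7, 7, 0, 0] and the mask is [T, T, T, F, F], so
    # the two trailing pads are dropped while the interior pad at index 1 is
    # kept.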
    for key in dynamic_keys:
      example[key] = tf.boolean_mask(example[key], mask)

    # masked_lm_ids should be 0-padded. Change the mask features to -1
    # padding so that padding from the data can be differentiated from
    # padding added by bucketizing.
    mask = tf.math.not_equal(example['masked_lm_ids'], 0)
    example['masked_lm_ids'] = tf.where(
        mask, example['masked_lm_ids'],
        -tf.ones(
            tf.shape(example['masked_lm_ids']),
            dtype=example['masked_lm_ids'].dtype))

    # tf.Example only supports tf.int64, but the TPU only supports tf.int32,
    # so cast all int64 to int32. tf.data service uses the dataset graph
    # fingerprint to distinguish input pipeline jobs, thus the keys are
    # sorted here to make sure they are generated in a deterministic order
    # each time the dataset function is traced.
    for name in sorted(example.keys()):
      t = example[name]
      if t.dtype == tf.int64:
        t = tf.cast(t, tf.int32)
      example[name] = t

    return example

  def _bucketize_and_batch(
      self,
      dataset,
      input_context: Optional[tf.distribute.InputContext] = None):
    """Bucketizes by sequence length and batches the dataset."""
    per_replica_batch_size = input_context.get_per_replica_batch_size(
        self._global_batch_size) if input_context else self._global_batch_size

    def element_length_func(example, seq_len_dim):
      return tf.shape(example['input_word_ids'])[seq_len_dim]

    bucket_boundaries = [length + 1 for length in self._seq_bucket_lengths]
    bucket_batch_sizes = [per_replica_batch_size] * (len(bucket_boundaries) + 1)

    # Bucketize and batch the dataset with the per-replica batch size first.
    dataset = dataset.apply(
        tf.data.experimental.bucket_by_sequence_length(
            lambda example: tf.cast(element_length_func(example, 0), tf.int32),
            bucket_boundaries,
            bucket_batch_sizes,
            pad_to_bucket_boundary=True,
            drop_remainder=self._drop_remainder))
    if input_context:
      window_size = input_context.num_replicas_in_sync
      if self._enable_tf_data_service and (
          not self._enable_round_robin_tf_data_service):
        # If tf.data service is enabled but round-robin behavior is not,
        # different TPU workers may fetch data from one tf.data service
        # worker at different speeds. Scale the window size by
        # `seq_bucket_window_scale` to leave a buffer in case some workers
        # fetch data faster than others, so all the data within the same
        # global batch still has a better chance of landing in the same
        # bucket.
        window_size *= self._seq_bucket_window_scale

      # Group `num_replicas_in_sync` batches from the same bucket together,
      # so all replicas get the same sequence length for one global step.
      dataset = dataset.apply(
          tf.data.experimental.group_by_window(
              key_func=lambda example: tf.cast(  # pylint: disable=g-long-lambda
                  element_length_func(example, 1), tf.int64),
              reduce_func=lambda _, x: tf.data.Dataset.from_tensors(x),
              window_size=window_size))
      dataset = dataset.flat_map(lambda x: x)

    def _remove_pads_from_bucketize(features):
      # All mask features must have the same effective length. The real
      # masked-id padding token is -1; any 0 comes from
      # bucket_by_sequence_length.
      mask = tf.math.not_equal(features['masked_lm_ids'], 0)

      mask_per_example = tf.math.reduce_sum(tf.cast(mask, tf.int32), axis=1)
      normalized = tf.cast(
          mask_per_example / tf.math.reduce_max(mask_per_example), tf.int32)
      assert_op = tf.debugging.assert_equal(
          tf.math.reduce_sum(normalized), per_replica_batch_size,
          'Number of non-padded mask tokens is not the same for each example '
          'with the same sequence length.')
      with tf.control_dependencies([assert_op]):
        for key in self._mask_keys:
          features[key] = tf.reshape(
              tf.boolean_mask(features[key], mask),
              [per_replica_batch_size, -1])
      # Revert masked_lm_ids to being 0-padded.
      mask = tf.math.not_equal(features['masked_lm_ids'], -1)
      features['masked_lm_ids'] = tf.where(
          mask, features['masked_lm_ids'],
          tf.zeros(
              tf.shape(features['masked_lm_ids']),
              dtype=features['masked_lm_ids'].dtype))
      return features

    dataset = dataset.map(_remove_pads_from_bucketize)
    return dataset
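
  # Illustrative note on the bucketing above (hypothetical shapes): with
  # seq_bucket_lengths=(128, 256), bucket_boundaries is [129, 257], so an
  # example of length 130 falls into the second bucket and, because
  # pad_to_bucket_boundary=True pads to boundary - 1, it is padded to length
  # 256; group_by_window then keeps every replica on the same bucket for a
  # given global step.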

  def load(self, input_context: Optional[tf.distribute.InputContext] = None):
    """Returns a tf.data.Dataset."""
    reader = input_reader.InputReader(
        params=self._params,
        decoder_fn=self._decode,
        parser_fn=self._parse,
        transform_and_batch_fn=self._bucketize_and_batch)
    return reader.read(input_context)
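
The loader registers itself with data_loader_factory, so before this deletion it could be obtained from its config rather than constructed directly. A minimal usage sketch, assuming the module is still importable (this commit removes it) and with a placeholder TFRecord glob and arbitrary bucket lengths:

  from official.nlp.data import data_loader_factory
  from official.nlp.data import pretrain_dynamic_dataloader  # runs the registration decorator

  config = pretrain_dynamic_dataloader.BertPretrainDataConfig(
      input_path='/path/to/pretrain*.tfrecord',  # placeholder
      global_batch_size=512,
      seq_bucket_lengths=(128, 512))
  loader = data_loader_factory.get_data_loader(config)
  dataset = loader.load()  # optionally pass a tf.distribute.InputContext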