Pradeep Kumar committed
Commit 4905638 · verified · 1 Parent(s): a39862f

Delete pretrain_dynamic_dataloader.py

Files changed (1):
  1. pretrain_dynamic_dataloader.py +0 -223
pretrain_dynamic_dataloader.py DELETED
@@ -1,223 +0,0 @@
- # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- """Dataset loader for pre-training with dynamic sequence length."""
- from typing import Optional, Tuple
-
- import dataclasses
- import tensorflow as tf, tf_keras
-
- from official.core import config_definitions as cfg
- from official.core import input_reader
- from official.nlp.data import data_loader_factory
- from official.nlp.data import pretrain_dataloader
-
-
- @dataclasses.dataclass
- class BertPretrainDataConfig(cfg.DataConfig):
-   """Data config for BERT pretraining task (tasks/masked_lm)."""
-   input_path: str = ''
-   global_batch_size: int = 512
-   is_training: bool = True
-   seq_bucket_lengths: Tuple[int, ...] = (128, 256, 384, 512,)
-   # TODO(rxsang): `seq_bucket_window_scale` is only useful when round robin
-   # tf.data service is disabled. Deprecate this flag once we always enable
-   # round robin tf.data service.
-   seq_bucket_window_scale: int = 8
-   use_next_sentence_label: bool = True
-   use_position_id: bool = False
-   deterministic: bool = False
-   enable_tf_data_service: bool = False
-   enable_round_robin_tf_data_service: bool = False
-   tf_data_service_job_name: str = 'bert_pretrain'
-   use_v2_feature_names: bool = False
-
-
- @data_loader_factory.register_data_loader_cls(BertPretrainDataConfig)
- class PretrainingDynamicDataLoader(pretrain_dataloader.BertPretrainDataLoader):
-   """Dataset loader for BERT-style pretraining with dynamic sequence length.
-
-   Bucketizes the input id features by `seq_bucket_lengths` and pads the
-   features to the bucket boundaries. The mask features are usually shorter
-   than the input id features and can also be dynamic. We require that the
-   mask feature lengths within a bucket be the same. For example, with
-   [128, 256] buckets, the mask features for bucket 128 should always have
-   length X and the mask features for bucket 256 should always have length Y.
-
-   The dataloader does not filter out empty masks. Make sure to handle this
-   in the model.
-   """
-
-   def __init__(self, params):
-     self._params = params
-     if len(params.seq_bucket_lengths) < 1:
-       raise ValueError('The seq_bucket_lengths cannot be empty.')
-     self._seq_bucket_lengths = params.seq_bucket_lengths
-     self._seq_bucket_window_scale = params.seq_bucket_window_scale
-     self._global_batch_size = params.global_batch_size
-     self._use_next_sentence_label = params.use_next_sentence_label
-     self._use_position_id = params.use_position_id
-     self._drop_remainder = params.drop_remainder
-     self._enable_tf_data_service = params.enable_tf_data_service
-     self._enable_round_robin_tf_data_service = (
-         params.enable_round_robin_tf_data_service)
-     self._mask_keys = [
-         'masked_lm_positions', 'masked_lm_ids', 'masked_lm_weights'
-     ]
-
-   def _decode(self, record: tf.Tensor):
-     """Decodes a serialized tf.Example."""
-     name_to_features = {
-         'input_mask': tf.io.VarLenFeature(tf.int64),
-         'masked_lm_positions': tf.io.VarLenFeature(tf.int64),
-         'masked_lm_ids': tf.io.VarLenFeature(tf.int64),
-         'masked_lm_weights': tf.io.VarLenFeature(tf.float32),
-     }
-     if self._params.use_v2_feature_names:
-       input_ids_key = 'input_word_ids'
-       segment_key = 'input_type_ids'
-       name_to_features.update({
-           input_ids_key: tf.io.VarLenFeature(tf.int64),
-           segment_key: tf.io.VarLenFeature(tf.int64),
-       })
-     else:
-       input_ids_key = 'input_ids'
-       segment_key = 'segment_ids'
-       name_to_features.update({
-           input_ids_key: tf.io.VarLenFeature(tf.int64),
-           segment_key: tf.io.VarLenFeature(tf.int64),
-       })
-     if self._use_next_sentence_label:
-       name_to_features['next_sentence_labels'] = tf.io.FixedLenFeature(
-           [1], tf.int64)
-     dynamic_keys = [input_ids_key, 'input_mask', segment_key]
-     if self._use_position_id:
-       name_to_features['position_ids'] = tf.io.VarLenFeature(tf.int64)
-       dynamic_keys.append('position_ids')
-
-     example = tf.io.parse_single_example(record, name_to_features)
-     for key in dynamic_keys + self._mask_keys:
-       example[key] = tf.sparse.to_dense(example[key])
-
-     # Truncate the padding that follows the last non-pad token along the
-     # sequence length dimension; padding that appears before the last
-     # non-pad token is kept.
-     mask = tf.math.greater(
-         tf.math.cumsum(example[input_ids_key], reverse=True), 0)
-     for key in dynamic_keys:
-       example[key] = tf.boolean_mask(example[key], mask)
-
-     # masked_lm_ids should be 0-padded. Change the mask features to -1
-     # padding so that we can differentiate the original padding from real
-     # data and from the padding added by bucketizing.
-     mask = tf.math.not_equal(example['masked_lm_ids'], 0)
-     example['masked_lm_ids'] = tf.where(
-         mask, example['masked_lm_ids'],
-         -tf.ones(
-             tf.shape(example['masked_lm_ids']),
-             dtype=example['masked_lm_ids'].dtype))
-
-     # tf.Example only supports tf.int64, but the TPU only supports tf.int32,
-     # so cast all int64 to int32.
-     # tf.data service uses the dataset graph fingerprint to distinguish input
-     # pipeline jobs, thus we sort the keys here to make sure they are
-     # generated in a deterministic order each time the dataset function is
-     # traced.
-     for name in sorted(list(example.keys())):
-       t = example[name]
-       if t.dtype == tf.int64:
-         t = tf.cast(t, tf.int32)
-       example[name] = t
-
-     return example
-
-   def _bucketize_and_batch(
-       self,
-       dataset,
-       input_context: Optional[tf.distribute.InputContext] = None):
-     """Bucketizes by sequence length and batches the dataset."""
-     per_replica_batch_size = input_context.get_per_replica_batch_size(
-         self._global_batch_size) if input_context else self._global_batch_size
-
-     def element_length_func(example, seq_len_dim):
-       return tf.shape(example['input_word_ids'])[seq_len_dim]
-
-     bucket_boundaries = [length + 1 for length in self._seq_bucket_lengths]
-     bucket_batch_sizes = [per_replica_batch_size] * (len(bucket_boundaries) + 1)
-
-     # Bucketize and batch the dataset with the per-replica batch size first.
-     dataset = dataset.apply(
-         tf.data.experimental.bucket_by_sequence_length(
-             lambda example: tf.cast(element_length_func(example, 0), tf.int32),
-             bucket_boundaries,
-             bucket_batch_sizes,
-             pad_to_bucket_boundary=True,
-             drop_remainder=self._drop_remainder))
-     if input_context:
-       window_size = input_context.num_replicas_in_sync
-       if self._enable_tf_data_service and (
-           not self._enable_round_robin_tf_data_service):
-         # If tf.data service is enabled but round-robin behavior is not,
-         # different TPU workers may fetch data from a single tf.data service
-         # worker at different speeds. We make the window size
-         # `seq_bucket_window_scale` times larger to leave a buffer when some
-         # workers fetch data faster than others, so all the data within the
-         # same global batch still has a good chance of landing in the same
-         # bucket.
-         window_size *= self._seq_bucket_window_scale
-
-       # Group `num_replicas_in_sync` batches from the same bucket together,
-       # so all replicas get the same sequence length for one global step.
-       dataset = dataset.apply(
-           tf.data.experimental.group_by_window(
-               key_func=lambda example: tf.cast(  # pylint: disable=g-long-lambda
-                   element_length_func(example, 1), tf.int64),
-               reduce_func=lambda _, x: tf.data.Dataset.from_tensors(x),
-               window_size=window_size))
-       dataset = dataset.flat_map(lambda x: x)
-
-     def _remove_pads_from_bucketize(features):
-       # All mask features must have the same effective length.
-       # The real padding token for masked ids is -1; the 0 padding comes from
-       # bucket_by_sequence_length.
-       mask = tf.math.not_equal(features['masked_lm_ids'], 0)
-
-       mask_per_example = tf.math.reduce_sum(tf.cast(mask, tf.int32), axis=1)
-       normalized = tf.cast(
-           mask_per_example / tf.math.reduce_max(mask_per_example), tf.int32)
-       assert_op = tf.debugging.assert_equal(
-           tf.math.reduce_sum(normalized), per_replica_batch_size,
-           'Number of non-padded mask tokens is not the same for every '
-           'example with the same sequence length.')
-       with tf.control_dependencies([assert_op]):
-         for key in self._mask_keys:
-           features[key] = tf.reshape(
-               tf.boolean_mask(features[key], mask),
-               [per_replica_batch_size, -1])
-       # Revert masked_lm_ids to be 0-padded.
-       mask = tf.math.not_equal(features['masked_lm_ids'], -1)
-       features['masked_lm_ids'] = tf.where(
-           mask, features['masked_lm_ids'],
-           tf.zeros(
-               tf.shape(features['masked_lm_ids']),
-               dtype=features['masked_lm_ids'].dtype))
-       return features
-
-     dataset = dataset.map(_remove_pads_from_bucketize)
-     return dataset
-
-   def load(self, input_context: Optional[tf.distribute.InputContext] = None):
-     """Returns a tf.data.Dataset."""
-     reader = input_reader.InputReader(
-         params=self._params,
-         decoder_fn=self._decode,
-         parser_fn=self._parse,
-         transform_and_batch_fn=self._bucketize_and_batch)
-     return reader.read(input_context)
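
For reference, the core behaviour this deleted loader provided was length-bucketed batching: variable-length examples are grouped by sequence length and padded only up to their bucket boundary rather than to one global maximum. Below is a minimal, standalone sketch of that bucketing step, assuming TensorFlow 2.x; it is not part of the deleted file, and the toy data, bucket lengths, and batch size are illustrative assumptions. Only the `input_word_ids` feature is modelled.

    import tensorflow as tf

    # Illustrative values; the real config used seq_bucket_lengths = (128, 256, 384, 512).
    seq_bucket_lengths = (4, 8)
    bucket_boundaries = [length + 1 for length in seq_bucket_lengths]
    per_replica_batch_size = 2

    def generate_toy_examples():
      # Four variable-length "sentences": two fall in the length-4 bucket,
      # two fall in the length-8 bucket.
      for ids in ([1, 2], [3, 4, 5], [6, 7, 8, 9, 10], [11, 12, 13, 14, 15, 16]):
        yield {'input_word_ids': tf.constant(ids, dtype=tf.int32)}

    dataset = tf.data.Dataset.from_generator(
        generate_toy_examples,
        output_signature={'input_word_ids': tf.TensorSpec([None], tf.int32)})

    # The same transformation the loader applied: bucket examples by length
    # and pad each batch only up to its bucket boundary.
    dataset = dataset.apply(
        tf.data.experimental.bucket_by_sequence_length(
            lambda example: tf.shape(example['input_word_ids'])[0],
            bucket_boundaries,
            [per_replica_batch_size] * (len(bucket_boundaries) + 1),
            pad_to_bucket_boundary=True))

    for batch in dataset:
      # Prints (2, 4) then (2, 8): each batch has the fixed shape of its bucket.
      print(batch['input_word_ids'].shape)

On top of this step, the deleted loader also grouped `num_replicas_in_sync` per-replica batches from the same bucket with `group_by_window`, so that every replica sees the same sequence length in a given global step.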