Pradeep Kumar committed on
Commit 0de7ffc · verified · 1 Parent(s): f891af0

Delete pretrain_dataloader.py

Files changed (1)
  1. pretrain_dataloader.py +0 -589
pretrain_dataloader.py DELETED
@@ -1,589 +0,0 @@
- # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- """Loads dataset for the BERT pretraining task."""
- import dataclasses
- from typing import Mapping, Optional
-
- from absl import logging
-
- import numpy as np
- import tensorflow as tf, tf_keras
- from official.common import dataset_fn
- from official.core import config_definitions as cfg
- from official.core import input_reader
- from official.nlp.data import data_loader
- from official.nlp.data import data_loader_factory
-
-
- @dataclasses.dataclass
- class BertPretrainDataConfig(cfg.DataConfig):
-   """Data config for BERT pretraining task (tasks/masked_lm)."""
-   input_path: str = ''
-   global_batch_size: int = 512
-   is_training: bool = True
-   seq_length: int = 512
-   max_predictions_per_seq: int = 76
-   use_next_sentence_label: bool = True
-   use_position_id: bool = False
-   # Historically, BERT implementations take `input_ids` and `segment_ids` as
-   # feature names. Inside the TF Model Garden implementation, the Keras model
-   # inputs are set as `input_word_ids` and `input_type_ids`. When
-   # v2_feature_names is True, the data loader assumes the tf.Examples use
-   # `input_word_ids` and `input_type_ids` as keys.
-   use_v2_feature_names: bool = False
-   file_type: str = 'tfrecord'
-
-
- @data_loader_factory.register_data_loader_cls(BertPretrainDataConfig)
- class BertPretrainDataLoader(data_loader.DataLoader):
-   """A class to load dataset for bert pretraining task."""
-
-   def __init__(self, params):
-     """Inits `BertPretrainDataLoader` class.
-
-     Args:
-       params: A `BertPretrainDataConfig` object.
-     """
-     self._params = params
-     self._seq_length = params.seq_length
-     self._max_predictions_per_seq = params.max_predictions_per_seq
-     self._use_next_sentence_label = params.use_next_sentence_label
-     self._use_position_id = params.use_position_id
-
-   def _name_to_features(self):
-     name_to_features = {
-         'input_mask':
-             tf.io.FixedLenFeature([self._seq_length], tf.int64),
-         'masked_lm_positions':
-             tf.io.FixedLenFeature([self._max_predictions_per_seq], tf.int64),
-         'masked_lm_ids':
-             tf.io.FixedLenFeature([self._max_predictions_per_seq], tf.int64),
-         'masked_lm_weights':
-             tf.io.FixedLenFeature([self._max_predictions_per_seq], tf.float32),
-     }
-     if self._params.use_v2_feature_names:
-       name_to_features.update({
-           'input_word_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
-           'input_type_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
-       })
-     else:
-       name_to_features.update({
-           'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
-           'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
-       })
-     if self._use_next_sentence_label:
-       name_to_features['next_sentence_labels'] = tf.io.FixedLenFeature([1],
-                                                                        tf.int64)
-     if self._use_position_id:
-       name_to_features['position_ids'] = tf.io.FixedLenFeature(
-           [self._seq_length], tf.int64)
-     return name_to_features
-
-   def _decode(self, record: tf.Tensor):
-     """Decodes a serialized tf.Example."""
-     name_to_features = self._name_to_features()
-     example = tf.io.parse_single_example(record, name_to_features)
-
-     # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
-     # So cast all int64 to int32.
-     for name in list(example.keys()):
-       t = example[name]
-       if t.dtype == tf.int64:
-         t = tf.cast(t, tf.int32)
-       example[name] = t
-
-     return example
-
-   def _parse(self, record: Mapping[str, tf.Tensor]):
-     """Parses raw tensors into a dict of tensors to be consumed by the model."""
-     x = {
-         'input_mask': record['input_mask'],
-         'masked_lm_positions': record['masked_lm_positions'],
-         'masked_lm_ids': record['masked_lm_ids'],
-         'masked_lm_weights': record['masked_lm_weights'],
-     }
-     if self._params.use_v2_feature_names:
-       x['input_word_ids'] = record['input_word_ids']
-       x['input_type_ids'] = record['input_type_ids']
-     else:
-       x['input_word_ids'] = record['input_ids']
-       x['input_type_ids'] = record['segment_ids']
-     if self._use_next_sentence_label:
-       x['next_sentence_labels'] = record['next_sentence_labels']
-     if self._use_position_id:
-       x['position_ids'] = record['position_ids']
-
-     return x
-
-   def load(self, input_context: Optional[tf.distribute.InputContext] = None):
-     """Returns a tf.dataset.Dataset."""
-     reader = input_reader.InputReader(
-         params=self._params,
-         dataset_fn=dataset_fn.pick_dataset_fn(self._params.file_type),
-         decoder_fn=self._decode,
-         parser_fn=self._parse)
-     return reader.read(input_context)
-
-
- @dataclasses.dataclass
- class XLNetPretrainDataConfig(cfg.DataConfig):
-   """Data config for XLNet pretraining task.
-
-   Attributes:
-     input_path: See base class.
-     global_batch_size: See base class.
-     is_training: See base class.
-     seq_length: The length of each sequence.
-     max_predictions_per_seq: The number of predictions per sequence.
-     reuse_length: The number of tokens in a previous segment to reuse. This
-       should be the same value used during pretrain data creation.
-     sample_strategy: The strategy used to sample factorization permutations.
-       Possible values: 'single_token', 'whole_word', 'token_span', 'word_span'.
-     min_num_tokens: The minimum number of tokens to sample in a span. This is
-       used when `sample_strategy` is 'token_span'.
-     max_num_tokens: The maximum number of tokens to sample in a span. This is
-       used when `sample_strategy` is 'token_span'.
-     min_num_words: The minimum number of words to sample in a span. This is used
-       when `sample_strategy` is 'word_span'.
-     max_num_words: The maximum number of words to sample in a span. This is used
-       when `sample_strategy` is 'word_span'.
-     permutation_size: The length of the longest permutation. This can be set to
-       `reuse_length`. This should NOT be greater than `reuse_length`, otherwise
-       this may introduce data leaks.
-     leak_ratio: The percentage of masked tokens that are leaked.
-     segment_sep_id: The ID of the SEP token used when preprocessing the dataset.
-     segment_cls_id: The ID of the CLS token used when preprocessing the dataset.
-   """
-   input_path: str = ''
-   global_batch_size: int = 512
-   is_training: bool = True
-   seq_length: int = 512
-   max_predictions_per_seq: int = 76
-   reuse_length: int = 256
-   sample_strategy: str = 'word_span'
-   min_num_tokens: int = 1
-   max_num_tokens: int = 5
-   min_num_words: int = 1
-   max_num_words: int = 5
-   permutation_size: int = 256
-   leak_ratio: float = 0.1
-   segment_sep_id: int = 4
-   segment_cls_id: int = 3
-
-
- @data_loader_factory.register_data_loader_cls(XLNetPretrainDataConfig)
- class XLNetPretrainDataLoader(data_loader.DataLoader):
-   """A class to load dataset for xlnet pretraining task."""
-
-   def __init__(self, params: XLNetPretrainDataConfig):
-     """Inits `XLNetPretrainDataLoader` class.
-
-     Args:
-       params: A `XLNetPretrainDataConfig` object.
-     """
-     self._params = params
-     self._seq_length = params.seq_length
-     self._max_predictions_per_seq = params.max_predictions_per_seq
-     self._reuse_length = params.reuse_length
-     self._num_replicas_in_sync = None
-     self._permutation_size = params.permutation_size
-     self._sep_id = params.segment_sep_id
-     self._cls_id = params.segment_cls_id
-     self._sample_strategy = params.sample_strategy
-     self._leak_ratio = params.leak_ratio
-
-   def _decode(self, record: tf.Tensor):
-     """Decodes a serialized tf.Example."""
-     name_to_features = {
-         'input_word_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
-         'input_type_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
-         'boundary_indices': tf.io.VarLenFeature(tf.int64),
-     }
-     example = tf.io.parse_single_example(record, name_to_features)
-
-     # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
-     # So cast all int64 to int32.
-     for name in list(example.keys()):
-       t = example[name]
-       if t.dtype == tf.int64:
-         t = tf.cast(t, tf.int32)
-       example[name] = t
-
-     return example
-
-   def _parse(self, record: Mapping[str, tf.Tensor]):
-     """Parses raw tensors into a dict of tensors to be consumed by the model."""
-     x = {}
-
-     inputs = record['input_word_ids']
-     x['input_type_ids'] = record['input_type_ids']
-
-     if self._sample_strategy in ['whole_word', 'word_span']:
-       boundary = tf.sparse.to_dense(record['boundary_indices'])
-     else:
-       boundary = None
-
-     input_mask = self._online_sample_mask(inputs=inputs, boundary=boundary)
-
-     if self._reuse_length > 0:
-       if self._permutation_size > self._reuse_length:
-         logging.warning(
-             '`permutation_size` is greater than `reuse_length` (%d > %d).'
-             'This may introduce data leakage.', self._permutation_size,
-             self._reuse_length)
-
-       # Enable the memory mechanism.
-       # Permute the reuse and non-reuse segments separately.
-       non_reuse_len = self._seq_length - self._reuse_length
-       if not (self._reuse_length % self._permutation_size == 0 and
-               non_reuse_len % self._permutation_size == 0):
-         raise ValueError('`reuse_length` and `seq_length` should both be '
-                          'a multiple of `permutation_size`.')
-
-       # Creates permutation mask and target mask for the first reuse_len tokens.
-       # The tokens in this part are reused from the last sequence.
-       perm_mask_0, target_mask_0, tokens_0, masked_0 = self._get_factorization(
-           inputs=inputs[:self._reuse_length],
-           input_mask=input_mask[:self._reuse_length])
-
-       # Creates permutation mask and target mask for the rest of tokens in
-       # current example, which are concatenation of two new segments.
-       perm_mask_1, target_mask_1, tokens_1, masked_1 = self._get_factorization(
-           inputs[self._reuse_length:], input_mask[self._reuse_length:])
-
-       perm_mask_0 = tf.concat([
-           perm_mask_0,
-           tf.zeros([self._reuse_length, non_reuse_len], dtype=tf.int32)
-       ],
-                               axis=1)
-       perm_mask_1 = tf.concat([
-           tf.ones([non_reuse_len, self._reuse_length], dtype=tf.int32),
-           perm_mask_1
-       ],
-                               axis=1)
-       perm_mask = tf.concat([perm_mask_0, perm_mask_1], axis=0)
-       target_mask = tf.concat([target_mask_0, target_mask_1], axis=0)
-       tokens = tf.concat([tokens_0, tokens_1], axis=0)
-       masked_tokens = tf.concat([masked_0, masked_1], axis=0)
-     else:
-       # Disable the memory mechanism.
-       if self._seq_length % self._permutation_size != 0:
-         raise ValueError('`seq_length` should be a multiple of '
-                          '`permutation_size`.')
-       # Permute the entire sequence together
-       perm_mask, target_mask, tokens, masked_tokens = self._get_factorization(
-           inputs=inputs, input_mask=input_mask)
-     x['permutation_mask'] = tf.reshape(perm_mask,
-                                        [self._seq_length, self._seq_length])
-     x['input_word_ids'] = tokens
-     x['masked_tokens'] = masked_tokens
-
-     target = tokens
-     if self._max_predictions_per_seq is not None:
-       indices = tf.range(self._seq_length, dtype=tf.int32)
-       bool_target_mask = tf.cast(target_mask, tf.bool)
-       indices = tf.boolean_mask(indices, bool_target_mask)
-
-       # account for extra padding due to CLS/SEP.
-       actual_num_predict = tf.shape(indices)[0]
-       pad_len = self._max_predictions_per_seq - actual_num_predict
-
-       target_mapping = tf.one_hot(indices, self._seq_length, dtype=tf.int32)
-       paddings = tf.zeros([pad_len, self._seq_length],
-                           dtype=target_mapping.dtype)
-       target_mapping = tf.concat([target_mapping, paddings], axis=0)
-       x['target_mapping'] = tf.reshape(
-           target_mapping, [self._max_predictions_per_seq, self._seq_length])
-
-       target = tf.boolean_mask(target, bool_target_mask)
-       paddings = tf.zeros([pad_len], dtype=target.dtype)
-       target = tf.concat([target, paddings], axis=0)
-       x['target'] = tf.reshape(target, [self._max_predictions_per_seq])
-
-       target_mask = tf.concat([
-           tf.ones([actual_num_predict], dtype=tf.int32),
-           tf.zeros([pad_len], dtype=tf.int32)
-       ],
-                               axis=0)
-       x['target_mask'] = tf.reshape(target_mask,
-                                     [self._max_predictions_per_seq])
-     else:
-       x['target'] = tf.reshape(target, [self._seq_length])
-       x['target_mask'] = tf.reshape(target_mask, [self._seq_length])
-     return x
-
-   def _index_pair_to_mask(self, begin_indices: tf.Tensor,
-                           end_indices: tf.Tensor,
-                           inputs: tf.Tensor) -> tf.Tensor:
-     """Converts beginning and end indices into an actual mask."""
-     non_func_mask = tf.logical_and(
-         tf.not_equal(inputs, self._sep_id), tf.not_equal(inputs, self._cls_id))
-     all_indices = tf.where(
-         non_func_mask, tf.range(self._seq_length, dtype=tf.int32),
-         tf.constant(-1, shape=[self._seq_length], dtype=tf.int32))
-     candidate_matrix = tf.cast(
-         tf.logical_and(all_indices[None, :] >= begin_indices[:, None],
-                        all_indices[None, :] < end_indices[:, None]), tf.float32)
-     cumsum_matrix = tf.reshape(
-         tf.cumsum(tf.reshape(candidate_matrix, [-1])), [-1, self._seq_length])
-     masked_matrix = tf.cast(cumsum_matrix <= self._max_predictions_per_seq,
-                             tf.float32)
-     target_mask = tf.reduce_sum(candidate_matrix * masked_matrix, axis=0)
-     return tf.cast(target_mask, tf.bool)
-
-   def _single_token_mask(self, inputs: tf.Tensor) -> tf.Tensor:
-     """Samples individual tokens as prediction targets."""
-     all_indices = tf.range(self._seq_length, dtype=tf.int32)
-     non_func_mask = tf.logical_and(
-         tf.not_equal(inputs, self._sep_id), tf.not_equal(inputs, self._cls_id))
-     non_func_indices = tf.boolean_mask(all_indices, non_func_mask)
-
-     masked_pos = tf.random.shuffle(non_func_indices)
-     masked_pos = tf.sort(masked_pos[:self._max_predictions_per_seq])
-
-     sparse_indices = tf.stack([tf.zeros_like(masked_pos), masked_pos], axis=-1)
-     sparse_indices = tf.cast(sparse_indices, tf.int64)
-
-     sparse_indices = tf.sparse.SparseTensor(
-         sparse_indices,
-         values=tf.ones_like(masked_pos),
-         dense_shape=(1, self._seq_length))
-
-     target_mask = tf.sparse.to_dense(sp_input=sparse_indices, default_value=0)
-
-     return tf.squeeze(tf.cast(target_mask, tf.bool))
-
-   def _whole_word_mask(self, inputs: tf.Tensor,
-                        boundary: tf.Tensor) -> tf.Tensor:
-     """Samples whole words as prediction targets."""
-     pair_indices = tf.concat([boundary[:-1, None], boundary[1:, None]], axis=1)
-     cand_pair_indices = tf.random.shuffle(
-         pair_indices)[:self._max_predictions_per_seq]
-     begin_indices = cand_pair_indices[:, 0]
-     end_indices = cand_pair_indices[:, 1]
-
-     return self._index_pair_to_mask(
-         begin_indices=begin_indices, end_indices=end_indices, inputs=inputs)
-
-   def _token_span_mask(self, inputs: tf.Tensor) -> tf.Tensor:
-     """Samples token spans as prediction targets."""
-     min_num_tokens = self._params.min_num_tokens
-     max_num_tokens = self._params.max_num_tokens
-
-     mask_alpha = self._seq_length / self._max_predictions_per_seq
-     round_to_int = lambda x: tf.cast(tf.round(x), tf.int32)
-
-     # Sample span lengths from a zipf distribution
-     span_len_seq = np.arange(min_num_tokens, max_num_tokens + 1)
-     probs = np.array([1.0 / (i + 1) for i in span_len_seq])
-
-     probs /= np.sum(probs)
-     logits = tf.constant(np.log(probs), dtype=tf.float32)
-     span_lens = tf.random.categorical(
-         logits=logits[None],
-         num_samples=self._max_predictions_per_seq,
-         dtype=tf.int32,
-     )[0] + min_num_tokens
-
-     # Sample the ratio [0.0, 1.0) of left context lengths
-     span_lens_float = tf.cast(span_lens, tf.float32)
-     left_ratio = tf.random.uniform(
-         shape=[self._max_predictions_per_seq], minval=0.0, maxval=1.0)
-     left_ctx_len = left_ratio * span_lens_float * (mask_alpha - 1)
-     left_ctx_len = round_to_int(left_ctx_len)
-
-     # Compute the offset from left start to the right end
-     right_offset = round_to_int(span_lens_float * mask_alpha) - left_ctx_len
-
-     # Get the actual begin and end indices
-     begin_indices = (
-         tf.cumsum(left_ctx_len) + tf.cumsum(right_offset, exclusive=True))
-     end_indices = begin_indices + span_lens
-
-     # Remove out of range indices
-     valid_idx_mask = end_indices < self._seq_length
-     begin_indices = tf.boolean_mask(begin_indices, valid_idx_mask)
-     end_indices = tf.boolean_mask(end_indices, valid_idx_mask)
-
-     # Shuffle valid indices
-     num_valid = tf.cast(tf.shape(begin_indices)[0], tf.int32)
-     order = tf.random.shuffle(tf.range(num_valid, dtype=tf.int32))
-     begin_indices = tf.gather(begin_indices, order)
-     end_indices = tf.gather(end_indices, order)
-
-     return self._index_pair_to_mask(
-         begin_indices=begin_indices, end_indices=end_indices, inputs=inputs)
-
-   def _word_span_mask(self, inputs: tf.Tensor, boundary: tf.Tensor):
-     """Sample whole word spans as prediction targets."""
-     min_num_words = self._params.min_num_words
-     max_num_words = self._params.max_num_words
-
-     # Note: 1.2 is the token-to-word ratio
-     mask_alpha = self._seq_length / self._max_predictions_per_seq / 1.2
-     round_to_int = lambda x: tf.cast(tf.round(x), tf.int32)
-
-     # Sample span lengths from a zipf distribution
-     span_len_seq = np.arange(min_num_words, max_num_words + 1)
-     probs = np.array([1.0 / (i + 1) for i in span_len_seq])
-     probs /= np.sum(probs)
-     logits = tf.constant(np.log(probs), dtype=tf.float32)
-
-     # Sample `num_predict` words here: note that this is over sampling
-     span_lens = tf.random.categorical(
-         logits=logits[None],
-         num_samples=self._max_predictions_per_seq,
-         dtype=tf.int32,
-     )[0] + min_num_words
-
-     # Sample the ratio [0.0, 1.0) of left context lengths
-     span_lens_float = tf.cast(span_lens, tf.float32)
-     left_ratio = tf.random.uniform(
-         shape=[self._max_predictions_per_seq], minval=0.0, maxval=1.0)
-     left_ctx_len = left_ratio * span_lens_float * (mask_alpha - 1)
-
-     left_ctx_len = round_to_int(left_ctx_len)
-     right_offset = round_to_int(span_lens_float * mask_alpha) - left_ctx_len
-
-     begin_indices = (
-         tf.cumsum(left_ctx_len) + tf.cumsum(right_offset, exclusive=True))
-     end_indices = begin_indices + span_lens
-
-     # Remove out of range indices
-     max_boundary_index = tf.cast(tf.shape(boundary)[0] - 1, tf.int32)
-     valid_idx_mask = end_indices < max_boundary_index
-     begin_indices = tf.boolean_mask(begin_indices, valid_idx_mask)
-     end_indices = tf.boolean_mask(end_indices, valid_idx_mask)
-
-     begin_indices = tf.gather(boundary, begin_indices)
-     end_indices = tf.gather(boundary, end_indices)
-
-     # Shuffle valid indices
-     num_valid = tf.cast(tf.shape(begin_indices)[0], tf.int32)
-     order = tf.random.shuffle(tf.range(num_valid, dtype=tf.int32))
-     begin_indices = tf.gather(begin_indices, order)
-     end_indices = tf.gather(end_indices, order)
-
-     return self._index_pair_to_mask(
-         begin_indices=begin_indices, end_indices=end_indices, inputs=inputs)
-
-   def _online_sample_mask(self, inputs: tf.Tensor,
-                           boundary: tf.Tensor) -> tf.Tensor:
-     """Samples target positions for predictions.
-
-     Descriptions of each strategy:
-       - 'single_token': Samples individual tokens as prediction targets.
-       - 'token_span': Samples spans of tokens as prediction targets.
-       - 'whole_word': Samples individual words as prediction targets.
-       - 'word_span': Samples spans of words as prediction targets.
-
-     Args:
-       inputs: The input tokens.
-       boundary: The `int` Tensor of indices indicating whole word boundaries.
-         This is used in 'whole_word' and 'word_span'
-
-     Returns:
-       The sampled `bool` input mask.
-
-     Raises:
-       `ValueError`: if `max_predictions_per_seq` is not set or if boundary is
-         not provided for 'whole_word' and 'word_span' sample strategies.
-     """
-     if self._max_predictions_per_seq is None:
-       raise ValueError('`max_predictions_per_seq` must be set.')
-
-     if boundary is None and 'word' in self._sample_strategy:
-       raise ValueError('`boundary` must be provided for {} strategy'.format(
-           self._sample_strategy))
-
-     if self._sample_strategy == 'single_token':
-       return self._single_token_mask(inputs)
-     elif self._sample_strategy == 'token_span':
-       return self._token_span_mask(inputs)
-     elif self._sample_strategy == 'whole_word':
-       return self._whole_word_mask(inputs, boundary)
-     elif self._sample_strategy == 'word_span':
-       return self._word_span_mask(inputs, boundary)
-     else:
-       raise NotImplementedError('Invalid sample strategy.')
-
-   def _get_factorization(self, inputs: tf.Tensor, input_mask: tf.Tensor):
-     """Samples a permutation of the factorization order.
-
-     Args:
-       inputs: the input tokens.
-       input_mask: the `bool` Tensor of the same shape as `inputs`. If `True`,
-         then this means select for partial prediction.
-
-     Returns:
-       perm_mask: An `int32` Tensor of shape [seq_length, seq_length] consisting
-         of 0s and 1s. If perm_mask[i][j] == 0, then this means that the i-th
-         token (in original order) cannot attend to the jth attention token.
-       target_mask: An `int32` Tensor of shape [seq_len] consisting of 0s and 1s.
-         If target_mask[i] == 1, then the i-th token needs to be predicted and
-         the mask will be used as input. This token will be included in the loss.
-         If target_mask[i] == 0, then the token (or [SEP], [CLS]) will be used as
-         input. This token will not be included in the loss.
-       tokens: int32 Tensor of shape [seq_length].
-       masked_tokens: int32 Tensor of shape [seq_length].
-     """
-     factorization_length = tf.shape(inputs)[0]
-     # Generate permutation indices
-     index = tf.range(factorization_length, dtype=tf.int32)
-     index = tf.transpose(tf.reshape(index, [-1, self._permutation_size]))
-     index = tf.random.shuffle(index)
-     index = tf.reshape(tf.transpose(index), [-1])
-
-     input_mask = tf.cast(input_mask, tf.bool)
-
-     # non-functional tokens
-     non_func_tokens = tf.logical_not(
-         tf.logical_or(
-             tf.equal(inputs, self._sep_id), tf.equal(inputs, self._cls_id)))
-     masked_tokens = tf.logical_and(input_mask, non_func_tokens)
-     non_masked_or_func_tokens = tf.logical_not(masked_tokens)
-
-     smallest_index = -2 * tf.ones([factorization_length], dtype=tf.int32)
-
-     # Similar to BERT, randomly leak some masked tokens
-     if self._leak_ratio > 0:
-       leak_tokens = tf.logical_and(
-           masked_tokens,
-           tf.random.uniform([factorization_length], maxval=1.0) <
-           self._leak_ratio)
-       can_attend_self = tf.logical_or(non_masked_or_func_tokens, leak_tokens)
-     else:
-       can_attend_self = non_masked_or_func_tokens
-     to_index = tf.where(can_attend_self, smallest_index, index)
-     from_index = tf.where(can_attend_self, to_index + 1, to_index)
-
-     # For masked tokens, can attend if i > j
-     # For context tokens, always can attend each other
-     can_attend = from_index[:, None] > to_index[None, :]
-
-     perm_mask = tf.cast(can_attend, tf.int32)
-
-     # Only masked tokens are included in the loss
-     target_mask = tf.cast(masked_tokens, tf.int32)
-
-     return perm_mask, target_mask, inputs, masked_tokens
-
-   def load(self, input_context: Optional[tf.distribute.InputContext] = None):
-     """Returns a tf.dataset.Dataset."""
-     if input_context:
-       self._num_replicas_in_sync = input_context.num_replicas_in_sync
-     reader = input_reader.InputReader(
-         params=self._params, decoder_fn=self._decode, parser_fn=self._parse)
-     return reader.read(input_context)
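
Usage note: a minimal sketch of how the deleted BERT loader was typically consumed through the factory registration shown in the diff above. It assumes `data_loader_factory` exposes a `get_data_loader(params)` lookup (as in the TF Model Garden package this file belonged to); the input path and sizes below are illustrative placeholders, not values taken from this commit.

    from official.nlp.data import data_loader_factory
    # Importing the module registers BertPretrainDataConfig with the factory.
    from official.nlp.data import pretrain_dataloader

    # Hypothetical config; the TFRecord glob is a placeholder.
    params = pretrain_dataloader.BertPretrainDataConfig(
        input_path='gs://my-bucket/pretrain-*.tfrecord',
        global_batch_size=256,
        seq_length=128,
        max_predictions_per_seq=20)

    # The factory resolves BertPretrainDataLoader from the registered config
    # class; load() builds the tf.data.Dataset of parsed feature dicts.
    loader = data_loader_factory.get_data_loader(params)
    dataset = loader.load()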