Pradeep Kumar
committed on
Delete pretrain_dataloader.py
pretrain_dataloader.py +0 -589
pretrain_dataloader.py DELETED
@@ -1,589 +0,0 @@
-# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Loads dataset for the BERT pretraining task."""
-import dataclasses
-from typing import Mapping, Optional
-
-from absl import logging
-
-import numpy as np
-import tensorflow as tf, tf_keras
-from official.common import dataset_fn
-from official.core import config_definitions as cfg
-from official.core import input_reader
-from official.nlp.data import data_loader
-from official.nlp.data import data_loader_factory
-
-
-@dataclasses.dataclass
-class BertPretrainDataConfig(cfg.DataConfig):
-  """Data config for BERT pretraining task (tasks/masked_lm)."""
-  input_path: str = ''
-  global_batch_size: int = 512
-  is_training: bool = True
-  seq_length: int = 512
-  max_predictions_per_seq: int = 76
-  use_next_sentence_label: bool = True
-  use_position_id: bool = False
-  # Historically, BERT implementations take `input_ids` and `segment_ids` as
-  # feature names. Inside the TF Model Garden implementation, the Keras model
-  # inputs are set as `input_word_ids` and `input_type_ids`. When
-  # v2_feature_names is True, the data loader assumes the tf.Examples use
-  # `input_word_ids` and `input_type_ids` as keys.
-  use_v2_feature_names: bool = False
-  file_type: str = 'tfrecord'
-
-
-@data_loader_factory.register_data_loader_cls(BertPretrainDataConfig)
-class BertPretrainDataLoader(data_loader.DataLoader):
-  """A class to load the dataset for the BERT pretraining task."""
-
-  def __init__(self, params):
-    """Inits `BertPretrainDataLoader` class.
-
-    Args:
-      params: A `BertPretrainDataConfig` object.
-    """
-    self._params = params
-    self._seq_length = params.seq_length
-    self._max_predictions_per_seq = params.max_predictions_per_seq
-    self._use_next_sentence_label = params.use_next_sentence_label
-    self._use_position_id = params.use_position_id
-
-  def _name_to_features(self):
-    name_to_features = {
-        'input_mask':
-            tf.io.FixedLenFeature([self._seq_length], tf.int64),
-        'masked_lm_positions':
-            tf.io.FixedLenFeature([self._max_predictions_per_seq], tf.int64),
-        'masked_lm_ids':
-            tf.io.FixedLenFeature([self._max_predictions_per_seq], tf.int64),
-        'masked_lm_weights':
-            tf.io.FixedLenFeature([self._max_predictions_per_seq], tf.float32),
-    }
-    if self._params.use_v2_feature_names:
-      name_to_features.update({
-          'input_word_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
-          'input_type_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
-      })
-    else:
-      name_to_features.update({
-          'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
-          'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
-      })
-    if self._use_next_sentence_label:
-      name_to_features['next_sentence_labels'] = tf.io.FixedLenFeature([1],
-                                                                       tf.int64)
-    if self._use_position_id:
-      name_to_features['position_ids'] = tf.io.FixedLenFeature(
-          [self._seq_length], tf.int64)
-    return name_to_features
-
-  def _decode(self, record: tf.Tensor):
-    """Decodes a serialized tf.Example."""
-    name_to_features = self._name_to_features()
-    example = tf.io.parse_single_example(record, name_to_features)
-
-    # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
-    # So cast all int64 to int32.
-    for name in list(example.keys()):
-      t = example[name]
-      if t.dtype == tf.int64:
-        t = tf.cast(t, tf.int32)
-      example[name] = t
-
-    return example
-
-  def _parse(self, record: Mapping[str, tf.Tensor]):
-    """Parses raw tensors into a dict of tensors to be consumed by the model."""
-    x = {
-        'input_mask': record['input_mask'],
-        'masked_lm_positions': record['masked_lm_positions'],
-        'masked_lm_ids': record['masked_lm_ids'],
-        'masked_lm_weights': record['masked_lm_weights'],
-    }
-    if self._params.use_v2_feature_names:
-      x['input_word_ids'] = record['input_word_ids']
-      x['input_type_ids'] = record['input_type_ids']
-    else:
-      x['input_word_ids'] = record['input_ids']
-      x['input_type_ids'] = record['segment_ids']
-    if self._use_next_sentence_label:
-      x['next_sentence_labels'] = record['next_sentence_labels']
-    if self._use_position_id:
-      x['position_ids'] = record['position_ids']
-
-    return x
-
-  def load(self, input_context: Optional[tf.distribute.InputContext] = None):
-    """Returns a tf.data.Dataset."""
-    reader = input_reader.InputReader(
-        params=self._params,
-        dataset_fn=dataset_fn.pick_dataset_fn(self._params.file_type),
-        decoder_fn=self._decode,
-        parser_fn=self._parse)
-    return reader.read(input_context)
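Because the class above is registered with `data_loader_factory`, callers typically built it from the config rather than instantiating it directly. A minimal usage sketch, assuming the upstream Model Garden import path and the factory's `get_data_loader` helper; the TFRecord path is hypothetical:

    from official.nlp.data import data_loader_factory
    from official.nlp.data import pretrain_dataloader

    config = pretrain_dataloader.BertPretrainDataConfig(
        input_path='/path/to/pretrain.tfrecord',  # hypothetical path
        global_batch_size=512,
        seq_length=512,
        max_predictions_per_seq=76,
        use_v2_feature_names=True)
    loader = data_loader_factory.get_data_loader(config)  # resolves the class registered above
    dataset = loader.load()  # a tf.data.Dataset of feature dicts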
-
-
-@dataclasses.dataclass
-class XLNetPretrainDataConfig(cfg.DataConfig):
-  """Data config for XLNet pretraining task.
-
-  Attributes:
-    input_path: See base class.
-    global_batch_size: See base class.
-    is_training: See base class.
-    seq_length: The length of each sequence.
-    max_predictions_per_seq: The number of predictions per sequence.
-    reuse_length: The number of tokens in a previous segment to reuse. This
-      should be the same value used during pretrain data creation.
-    sample_strategy: The strategy used to sample factorization permutations.
-      Possible values: 'single_token', 'whole_word', 'token_span', 'word_span'.
-    min_num_tokens: The minimum number of tokens to sample in a span. This is
-      used when `sample_strategy` is 'token_span'.
-    max_num_tokens: The maximum number of tokens to sample in a span. This is
-      used when `sample_strategy` is 'token_span'.
-    min_num_words: The minimum number of words to sample in a span. This is used
-      when `sample_strategy` is 'word_span'.
-    max_num_words: The maximum number of words to sample in a span. This is used
-      when `sample_strategy` is 'word_span'.
-    permutation_size: The length of the longest permutation. This can be set to
-      `reuse_length`. This should NOT be greater than `reuse_length`, otherwise
-      this may introduce data leaks.
-    leak_ratio: The percentage of masked tokens that are leaked.
-    segment_sep_id: The ID of the SEP token used when preprocessing the dataset.
-    segment_cls_id: The ID of the CLS token used when preprocessing the dataset.
-  """
-  input_path: str = ''
-  global_batch_size: int = 512
-  is_training: bool = True
-  seq_length: int = 512
-  max_predictions_per_seq: int = 76
-  reuse_length: int = 256
-  sample_strategy: str = 'word_span'
-  min_num_tokens: int = 1
-  max_num_tokens: int = 5
-  min_num_words: int = 1
-  max_num_words: int = 5
-  permutation_size: int = 256
-  leak_ratio: float = 0.1
-  segment_sep_id: int = 4
-  segment_cls_id: int = 3
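The defaults above satisfy the divisibility constraint that `_parse` enforces below: both the reused segment and the remainder of the sequence must be multiples of `permutation_size`. A config sketch making the constraint explicit, assuming the upstream Model Garden import path; the input path is hypothetical:

    from official.nlp.data import pretrain_dataloader

    xlnet_config = pretrain_dataloader.XLNetPretrainDataConfig(
        input_path='/path/to/xlnet_pretrain.tfrecord',  # hypothetical path
        seq_length=512,
        reuse_length=256,
        permutation_size=256,  # divides both 256 and 512 - 256
        sample_strategy='word_span')
    assert xlnet_config.reuse_length % xlnet_config.permutation_size == 0
    assert ((xlnet_config.seq_length - xlnet_config.reuse_length)
            % xlnet_config.permutation_size == 0)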
-
-
-@data_loader_factory.register_data_loader_cls(XLNetPretrainDataConfig)
-class XLNetPretrainDataLoader(data_loader.DataLoader):
-  """A class to load the dataset for the XLNet pretraining task."""
-
-  def __init__(self, params: XLNetPretrainDataConfig):
-    """Inits `XLNetPretrainDataLoader` class.
-
-    Args:
-      params: An `XLNetPretrainDataConfig` object.
-    """
-    self._params = params
-    self._seq_length = params.seq_length
-    self._max_predictions_per_seq = params.max_predictions_per_seq
-    self._reuse_length = params.reuse_length
-    self._num_replicas_in_sync = None
-    self._permutation_size = params.permutation_size
-    self._sep_id = params.segment_sep_id
-    self._cls_id = params.segment_cls_id
-    self._sample_strategy = params.sample_strategy
-    self._leak_ratio = params.leak_ratio
-
-  def _decode(self, record: tf.Tensor):
-    """Decodes a serialized tf.Example."""
-    name_to_features = {
-        'input_word_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
-        'input_type_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
-        'boundary_indices': tf.io.VarLenFeature(tf.int64),
-    }
-    example = tf.io.parse_single_example(record, name_to_features)
-
-    # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
-    # So cast all int64 to int32.
-    for name in list(example.keys()):
-      t = example[name]
-      if t.dtype == tf.int64:
-        t = tf.cast(t, tf.int32)
-      example[name] = t
-
-    return example
-
-  def _parse(self, record: Mapping[str, tf.Tensor]):
-    """Parses raw tensors into a dict of tensors to be consumed by the model."""
-    x = {}
-
-    inputs = record['input_word_ids']
-    x['input_type_ids'] = record['input_type_ids']
-
-    if self._sample_strategy in ['whole_word', 'word_span']:
-      boundary = tf.sparse.to_dense(record['boundary_indices'])
-    else:
-      boundary = None
-
-    input_mask = self._online_sample_mask(inputs=inputs, boundary=boundary)
-
-    if self._reuse_length > 0:
-      if self._permutation_size > self._reuse_length:
-        logging.warning(
-            '`permutation_size` is greater than `reuse_length` (%d > %d). '
-            'This may introduce data leakage.', self._permutation_size,
-            self._reuse_length)
-
-      # Enable the memory mechanism.
-      # Permute the reuse and non-reuse segments separately.
-      non_reuse_len = self._seq_length - self._reuse_length
-      if not (self._reuse_length % self._permutation_size == 0 and
-              non_reuse_len % self._permutation_size == 0):
-        raise ValueError('`reuse_length` and `seq_length` should both be '
-                         'a multiple of `permutation_size`.')
-
-      # Creates permutation mask and target mask for the first reuse_len tokens.
-      # The tokens in this part are reused from the last sequence.
-      perm_mask_0, target_mask_0, tokens_0, masked_0 = self._get_factorization(
-          inputs=inputs[:self._reuse_length],
-          input_mask=input_mask[:self._reuse_length])
-
-      # Creates permutation mask and target mask for the rest of tokens in
-      # current example, which are concatenation of two new segments.
-      perm_mask_1, target_mask_1, tokens_1, masked_1 = self._get_factorization(
-          inputs[self._reuse_length:], input_mask[self._reuse_length:])
-
-      perm_mask_0 = tf.concat([
-          perm_mask_0,
-          tf.zeros([self._reuse_length, non_reuse_len], dtype=tf.int32)
-      ],
-                              axis=1)
-      perm_mask_1 = tf.concat([
-          tf.ones([non_reuse_len, self._reuse_length], dtype=tf.int32),
-          perm_mask_1
-      ],
-                              axis=1)
-      perm_mask = tf.concat([perm_mask_0, perm_mask_1], axis=0)
-      target_mask = tf.concat([target_mask_0, target_mask_1], axis=0)
-      tokens = tf.concat([tokens_0, tokens_1], axis=0)
-      masked_tokens = tf.concat([masked_0, masked_1], axis=0)
-    else:
-      # Disable the memory mechanism.
-      if self._seq_length % self._permutation_size != 0:
-        raise ValueError('`seq_length` should be a multiple of '
-                         '`permutation_size`.')
-      # Permute the entire sequence together
-      perm_mask, target_mask, tokens, masked_tokens = self._get_factorization(
-          inputs=inputs, input_mask=input_mask)
-    x['permutation_mask'] = tf.reshape(perm_mask,
-                                       [self._seq_length, self._seq_length])
-    x['input_word_ids'] = tokens
-    x['masked_tokens'] = masked_tokens
-
-    target = tokens
-    if self._max_predictions_per_seq is not None:
-      indices = tf.range(self._seq_length, dtype=tf.int32)
-      bool_target_mask = tf.cast(target_mask, tf.bool)
-      indices = tf.boolean_mask(indices, bool_target_mask)
-
-      # account for extra padding due to CLS/SEP.
-      actual_num_predict = tf.shape(indices)[0]
-      pad_len = self._max_predictions_per_seq - actual_num_predict
-
-      target_mapping = tf.one_hot(indices, self._seq_length, dtype=tf.int32)
-      paddings = tf.zeros([pad_len, self._seq_length],
-                          dtype=target_mapping.dtype)
-      target_mapping = tf.concat([target_mapping, paddings], axis=0)
-      x['target_mapping'] = tf.reshape(
-          target_mapping, [self._max_predictions_per_seq, self._seq_length])
-
-      target = tf.boolean_mask(target, bool_target_mask)
-      paddings = tf.zeros([pad_len], dtype=target.dtype)
-      target = tf.concat([target, paddings], axis=0)
-      x['target'] = tf.reshape(target, [self._max_predictions_per_seq])
-
-      target_mask = tf.concat([
-          tf.ones([actual_num_predict], dtype=tf.int32),
-          tf.zeros([pad_len], dtype=tf.int32)
-      ],
-                              axis=0)
-      x['target_mask'] = tf.reshape(target_mask,
-                                    [self._max_predictions_per_seq])
-    else:
-      x['target'] = tf.reshape(target, [self._seq_length])
-      x['target_mask'] = tf.reshape(target_mask, [self._seq_length])
-    return x
-
-  def _index_pair_to_mask(self, begin_indices: tf.Tensor,
-                          end_indices: tf.Tensor,
-                          inputs: tf.Tensor) -> tf.Tensor:
-    """Converts beginning and end indices into an actual mask."""
-    non_func_mask = tf.logical_and(
-        tf.not_equal(inputs, self._sep_id), tf.not_equal(inputs, self._cls_id))
-    all_indices = tf.where(
-        non_func_mask, tf.range(self._seq_length, dtype=tf.int32),
-        tf.constant(-1, shape=[self._seq_length], dtype=tf.int32))
-    candidate_matrix = tf.cast(
-        tf.logical_and(all_indices[None, :] >= begin_indices[:, None],
-                       all_indices[None, :] < end_indices[:, None]), tf.float32)
-    cumsum_matrix = tf.reshape(
-        tf.cumsum(tf.reshape(candidate_matrix, [-1])), [-1, self._seq_length])
-    masked_matrix = tf.cast(cumsum_matrix <= self._max_predictions_per_seq,
-                            tf.float32)
-    target_mask = tf.reduce_sum(candidate_matrix * masked_matrix, axis=0)
-    return tf.cast(target_mask, tf.bool)
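The cumulative-sum step in `_index_pair_to_mask` budgets candidates in row-major order: candidate positions are numbered across all sampled spans, and only the first `max_predictions_per_seq` of them survive. A small NumPy illustration of just that step, with toy values rather than the class's defaults:

    import numpy as np

    seq_length, max_predictions_per_seq = 8, 3
    # Two candidate spans, [1, 3) and [4, 7), one row per span.
    candidate_matrix = np.array([[0, 1, 1, 0, 0, 0, 0, 0],
                                 [0, 0, 0, 0, 1, 1, 1, 0]], dtype=np.float32)
    cumsum_matrix = np.cumsum(candidate_matrix.reshape(-1)).reshape(-1, seq_length)
    masked_matrix = (cumsum_matrix <= max_predictions_per_seq).astype(np.float32)
    target_mask = (candidate_matrix * masked_matrix).sum(axis=0).astype(bool)
    # target_mask -> [F, T, T, F, T, F, F, F]: the budget of 3 keeps the first
    # span whole but only the first token of the second span.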
-
-  def _single_token_mask(self, inputs: tf.Tensor) -> tf.Tensor:
-    """Samples individual tokens as prediction targets."""
-    all_indices = tf.range(self._seq_length, dtype=tf.int32)
-    non_func_mask = tf.logical_and(
-        tf.not_equal(inputs, self._sep_id), tf.not_equal(inputs, self._cls_id))
-    non_func_indices = tf.boolean_mask(all_indices, non_func_mask)
-
-    masked_pos = tf.random.shuffle(non_func_indices)
-    masked_pos = tf.sort(masked_pos[:self._max_predictions_per_seq])
-
-    sparse_indices = tf.stack([tf.zeros_like(masked_pos), masked_pos], axis=-1)
-    sparse_indices = tf.cast(sparse_indices, tf.int64)
-
-    sparse_indices = tf.sparse.SparseTensor(
-        sparse_indices,
-        values=tf.ones_like(masked_pos),
-        dense_shape=(1, self._seq_length))
-
-    target_mask = tf.sparse.to_dense(sp_input=sparse_indices, default_value=0)
-
-    return tf.squeeze(tf.cast(target_mask, tf.bool))
-
-  def _whole_word_mask(self, inputs: tf.Tensor,
-                       boundary: tf.Tensor) -> tf.Tensor:
-    """Samples whole words as prediction targets."""
-    pair_indices = tf.concat([boundary[:-1, None], boundary[1:, None]], axis=1)
-    cand_pair_indices = tf.random.shuffle(
-        pair_indices)[:self._max_predictions_per_seq]
-    begin_indices = cand_pair_indices[:, 0]
-    end_indices = cand_pair_indices[:, 1]
-
-    return self._index_pair_to_mask(
-        begin_indices=begin_indices, end_indices=end_indices, inputs=inputs)
-
-  def _token_span_mask(self, inputs: tf.Tensor) -> tf.Tensor:
-    """Samples token spans as prediction targets."""
-    min_num_tokens = self._params.min_num_tokens
-    max_num_tokens = self._params.max_num_tokens
-
-    mask_alpha = self._seq_length / self._max_predictions_per_seq
-    round_to_int = lambda x: tf.cast(tf.round(x), tf.int32)
-
-    # Sample span lengths from a zipf distribution
-    span_len_seq = np.arange(min_num_tokens, max_num_tokens + 1)
-    probs = np.array([1.0 / (i + 1) for i in span_len_seq])
-
-    probs /= np.sum(probs)
-    logits = tf.constant(np.log(probs), dtype=tf.float32)
-    span_lens = tf.random.categorical(
-        logits=logits[None],
-        num_samples=self._max_predictions_per_seq,
-        dtype=tf.int32,
-    )[0] + min_num_tokens
-
-    # Sample the ratio [0.0, 1.0) of left context lengths
-    span_lens_float = tf.cast(span_lens, tf.float32)
-    left_ratio = tf.random.uniform(
-        shape=[self._max_predictions_per_seq], minval=0.0, maxval=1.0)
-    left_ctx_len = left_ratio * span_lens_float * (mask_alpha - 1)
-    left_ctx_len = round_to_int(left_ctx_len)
-
-    # Compute the offset from left start to the right end
-    right_offset = round_to_int(span_lens_float * mask_alpha) - left_ctx_len
-
-    # Get the actual begin and end indices
-    begin_indices = (
-        tf.cumsum(left_ctx_len) + tf.cumsum(right_offset, exclusive=True))
-    end_indices = begin_indices + span_lens
-
-    # Remove out of range indices
-    valid_idx_mask = end_indices < self._seq_length
-    begin_indices = tf.boolean_mask(begin_indices, valid_idx_mask)
-    end_indices = tf.boolean_mask(end_indices, valid_idx_mask)
-
-    # Shuffle valid indices
-    num_valid = tf.cast(tf.shape(begin_indices)[0], tf.int32)
-    order = tf.random.shuffle(tf.range(num_valid, dtype=tf.int32))
-    begin_indices = tf.gather(begin_indices, order)
-    end_indices = tf.gather(end_indices, order)
-
-    return self._index_pair_to_mask(
-        begin_indices=begin_indices, end_indices=end_indices, inputs=inputs)
-
-  def _word_span_mask(self, inputs: tf.Tensor, boundary: tf.Tensor):
-    """Samples whole word spans as prediction targets."""
-    min_num_words = self._params.min_num_words
-    max_num_words = self._params.max_num_words
-
-    # Note: 1.2 is the token-to-word ratio
-    mask_alpha = self._seq_length / self._max_predictions_per_seq / 1.2
-    round_to_int = lambda x: tf.cast(tf.round(x), tf.int32)
-
-    # Sample span lengths from a zipf distribution
-    span_len_seq = np.arange(min_num_words, max_num_words + 1)
-    probs = np.array([1.0 / (i + 1) for i in span_len_seq])
-    probs /= np.sum(probs)
-    logits = tf.constant(np.log(probs), dtype=tf.float32)
-
-    # Sample `num_predict` words here: note that this is over sampling
-    span_lens = tf.random.categorical(
-        logits=logits[None],
-        num_samples=self._max_predictions_per_seq,
-        dtype=tf.int32,
-    )[0] + min_num_words
-
-    # Sample the ratio [0.0, 1.0) of left context lengths
-    span_lens_float = tf.cast(span_lens, tf.float32)
-    left_ratio = tf.random.uniform(
-        shape=[self._max_predictions_per_seq], minval=0.0, maxval=1.0)
-    left_ctx_len = left_ratio * span_lens_float * (mask_alpha - 1)
-
-    left_ctx_len = round_to_int(left_ctx_len)
-    right_offset = round_to_int(span_lens_float * mask_alpha) - left_ctx_len
-
-    begin_indices = (
-        tf.cumsum(left_ctx_len) + tf.cumsum(right_offset, exclusive=True))
-    end_indices = begin_indices + span_lens
-
-    # Remove out of range indices
-    max_boundary_index = tf.cast(tf.shape(boundary)[0] - 1, tf.int32)
-    valid_idx_mask = end_indices < max_boundary_index
-    begin_indices = tf.boolean_mask(begin_indices, valid_idx_mask)
-    end_indices = tf.boolean_mask(end_indices, valid_idx_mask)
-
-    begin_indices = tf.gather(boundary, begin_indices)
-    end_indices = tf.gather(boundary, end_indices)
-
-    # Shuffle valid indices
-    num_valid = tf.cast(tf.shape(begin_indices)[0], tf.int32)
-    order = tf.random.shuffle(tf.range(num_valid, dtype=tf.int32))
-    begin_indices = tf.gather(begin_indices, order)
-    end_indices = tf.gather(end_indices, order)
-
-    return self._index_pair_to_mask(
-        begin_indices=begin_indices, end_indices=end_indices, inputs=inputs)
-
-  def _online_sample_mask(self, inputs: tf.Tensor,
-                          boundary: tf.Tensor) -> tf.Tensor:
-    """Samples target positions for predictions.
-
-    Descriptions of each strategy:
-      - 'single_token': Samples individual tokens as prediction targets.
-      - 'token_span': Samples spans of tokens as prediction targets.
-      - 'whole_word': Samples individual words as prediction targets.
-      - 'word_span': Samples spans of words as prediction targets.
-
-    Args:
-      inputs: The input tokens.
-      boundary: The `int` Tensor of indices indicating whole word boundaries.
-        This is used in 'whole_word' and 'word_span'.
-
-    Returns:
-      The sampled `bool` input mask.
-
-    Raises:
-      `ValueError`: if `max_predictions_per_seq` is not set or if boundary is
-        not provided for 'whole_word' and 'word_span' sample strategies.
-    """
-    if self._max_predictions_per_seq is None:
-      raise ValueError('`max_predictions_per_seq` must be set.')
-
-    if boundary is None and 'word' in self._sample_strategy:
-      raise ValueError('`boundary` must be provided for {} strategy'.format(
-          self._sample_strategy))
-
-    if self._sample_strategy == 'single_token':
-      return self._single_token_mask(inputs)
-    elif self._sample_strategy == 'token_span':
-      return self._token_span_mask(inputs)
-    elif self._sample_strategy == 'whole_word':
-      return self._whole_word_mask(inputs, boundary)
-    elif self._sample_strategy == 'word_span':
-      return self._word_span_mask(inputs, boundary)
-    else:
-      raise NotImplementedError('Invalid sample strategy.')
-
-  def _get_factorization(self, inputs: tf.Tensor, input_mask: tf.Tensor):
-    """Samples a permutation of the factorization order.
-
-    Args:
-      inputs: the input tokens.
-      input_mask: the `bool` Tensor of the same shape as `inputs`. If `True`,
-        then this means select for partial prediction.
-
-    Returns:
-      perm_mask: An `int32` Tensor of shape [seq_length, seq_length] consisting
-        of 0s and 1s. If perm_mask[i][j] == 0, then this means that the i-th
-        token (in original order) cannot attend to the j-th attention token.
-      target_mask: An `int32` Tensor of shape [seq_len] consisting of 0s and 1s.
-        If target_mask[i] == 1, then the i-th token needs to be predicted and
-        the mask will be used as input. This token will be included in the loss.
-        If target_mask[i] == 0, then the token (or [SEP], [CLS]) will be used as
-        input. This token will not be included in the loss.
-      tokens: int32 Tensor of shape [seq_length].
-      masked_tokens: int32 Tensor of shape [seq_length].
-    """
-    factorization_length = tf.shape(inputs)[0]
-    # Generate permutation indices
-    index = tf.range(factorization_length, dtype=tf.int32)
-    index = tf.transpose(tf.reshape(index, [-1, self._permutation_size]))
-    index = tf.random.shuffle(index)
-    index = tf.reshape(tf.transpose(index), [-1])
-
-    input_mask = tf.cast(input_mask, tf.bool)
-
-    # non-functional tokens
-    non_func_tokens = tf.logical_not(
-        tf.logical_or(
-            tf.equal(inputs, self._sep_id), tf.equal(inputs, self._cls_id)))
-    masked_tokens = tf.logical_and(input_mask, non_func_tokens)
-    non_masked_or_func_tokens = tf.logical_not(masked_tokens)
-
-    smallest_index = -2 * tf.ones([factorization_length], dtype=tf.int32)
-
-    # Similar to BERT, randomly leak some masked tokens
-    if self._leak_ratio > 0:
-      leak_tokens = tf.logical_and(
-          masked_tokens,
-          tf.random.uniform([factorization_length], maxval=1.0) <
-          self._leak_ratio)
-      can_attend_self = tf.logical_or(non_masked_or_func_tokens, leak_tokens)
-    else:
-      can_attend_self = non_masked_or_func_tokens
-    to_index = tf.where(can_attend_self, smallest_index, index)
-    from_index = tf.where(can_attend_self, to_index + 1, to_index)
-
-    # For masked tokens, can attend if i > j
-    # For context tokens, always can attend each other
-    can_attend = from_index[:, None] > to_index[None, :]
-
-    perm_mask = tf.cast(can_attend, tf.int32)
-
-    # Only masked tokens are included in the loss
-    target_mask = tf.cast(masked_tokens, tf.int32)
-
-    return perm_mask, target_mask, inputs, masked_tokens
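The `to_index`/`from_index` construction above encodes each position's place in the sampled factorization order: tokens allowed to attend to themselves get the sentinel -2 (so every real permutation index beats them), while masked tokens keep their permutation index. A toy NumPy walk-through of the comparison, with illustrative values only:

    import numpy as np

    index = np.array([2, 0, 3, 1], dtype=np.int32)          # sampled factorization order
    can_attend_self = np.array([True, False, True, True])   # position 1 is masked
    smallest_index = np.full(4, -2, dtype=np.int32)

    to_index = np.where(can_attend_self, smallest_index, index)     # [-2, 0, -2, -2]
    from_index = np.where(can_attend_self, to_index + 1, to_index)  # [-1, 0, -1, -1]
    can_attend = from_index[:, None] > to_index[None, :]
    # Every position can attend to the context tokens, and no position (the
    # masked token included) can attend to the masked position, matching
    # perm_mask[i][j] == 1 <=> i may attend to j.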
-
-  def load(self, input_context: Optional[tf.distribute.InputContext] = None):
-    """Returns a tf.data.Dataset."""
-    if input_context:
-      self._num_replicas_in_sync = input_context.num_replicas_in_sync
-    reader = input_reader.InputReader(
-        params=self._params, decoder_fn=self._decode, parser_fn=self._parse)
-    return reader.read(input_context)