Pradeep Kumar committed on
Commit f3dac86 · verified · 1 parent: 20e1eb7

Delete create_pretraining_data.py

Files changed (1)
  1. create_pretraining_data.py +0 -718
create_pretraining_data.py DELETED
@@ -1,718 +0,0 @@
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Create masked LM/next sentence masked_lm TF examples for BERT."""

import collections
import itertools
import random

# Import libraries

from absl import app
from absl import flags
from absl import logging
import tensorflow as tf, tf_keras

from official.nlp.tools import tokenization

FLAGS = flags.FLAGS

flags.DEFINE_string("input_file", None,
                    "Input raw text file (or comma-separated list of files).")

flags.DEFINE_string(
    "output_file", None,
    "Output TF example file (or comma-separated list of files).")

flags.DEFINE_enum(
    "tokenization",
    "WordPiece",
    ["WordPiece", "SentencePiece"],
    "Specifies the tokenizer implementation, i.e., whether to use WordPiece "
    "or SentencePiece tokenizer. Canonical BERT uses WordPiece tokenizer, "
    "while ALBERT uses SentencePiece tokenizer.",
)

flags.DEFINE_string(
    "vocab_file",
    None,
    "For WordPiece tokenization, the vocabulary file of the tokenizer.",
)

flags.DEFINE_string(
    "sp_model_file",
    "",
    "For SentencePiece tokenization, the path to the model of the tokenizer.",
)

flags.DEFINE_bool(
    "do_lower_case", True,
    "Whether to lower case the input text. Should be True for uncased "
    "models and False for cased models.")

flags.DEFINE_bool(
    "do_whole_word_mask",
    False,
    "Whether to use whole word masking rather than per-token masking.",
)

flags.DEFINE_integer(
    "max_ngram_size", None,
    "Mask contiguous whole words (n-grams) of up to `max_ngram_size` using a "
    "weighting scheme to favor shorter n-grams. "
    "Note: `--do_whole_word_mask=True` must also be set when n-gram masking.")

flags.DEFINE_bool(
    "gzip_compress", False,
    "Whether to use `GZIP` compress option to get compressed TFRecord files.")

flags.DEFINE_bool(
    "use_v2_feature_names", False,
    "Whether to use the feature names consistent with the models.")

flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.")

flags.DEFINE_integer("max_predictions_per_seq", 20,
                     "Maximum number of masked LM predictions per sequence.")

flags.DEFINE_integer("random_seed", 12345, "Random seed for data generation.")

flags.DEFINE_integer(
    "dupe_factor", 10,
    "Number of times to duplicate the input data (with different masks).")

flags.DEFINE_float("masked_lm_prob", 0.15, "Masked LM probability.")

flags.DEFINE_float(
    "short_seq_prob", 0.1,
    "Probability of creating sequences which are shorter than the "
    "maximum length.")


class TrainingInstance(object):
  """A single training instance (sentence pair)."""

  def __init__(self, tokens, segment_ids, masked_lm_positions, masked_lm_labels,
               is_random_next):
    self.tokens = tokens
    self.segment_ids = segment_ids
    self.is_random_next = is_random_next
    self.masked_lm_positions = masked_lm_positions
    self.masked_lm_labels = masked_lm_labels

  def __str__(self):
    s = ""
    s += "tokens: %s\n" % (" ".join(
        [tokenization.printable_text(x) for x in self.tokens]))
    s += "segment_ids: %s\n" % (" ".join([str(x) for x in self.segment_ids]))
    s += "is_random_next: %s\n" % self.is_random_next
    s += "masked_lm_positions: %s\n" % (" ".join(
        [str(x) for x in self.masked_lm_positions]))
    s += "masked_lm_labels: %s\n" % (" ".join(
        [tokenization.printable_text(x) for x in self.masked_lm_labels]))
    s += "\n"
    return s

  def __repr__(self):
    return self.__str__()


def write_instance_to_example_files(instances, tokenizer, max_seq_length,
                                    max_predictions_per_seq, output_files,
                                    gzip_compress, use_v2_feature_names):
  """Creates TF example files from `TrainingInstance`s."""
  writers = []
  for output_file in output_files:
    writers.append(
        tf.io.TFRecordWriter(
            output_file, options="GZIP" if gzip_compress else ""))

  writer_index = 0

  total_written = 0
  for (inst_index, instance) in enumerate(instances):
    input_ids = tokenizer.convert_tokens_to_ids(instance.tokens)
    input_mask = [1] * len(input_ids)
    segment_ids = list(instance.segment_ids)
    assert len(input_ids) <= max_seq_length

    while len(input_ids) < max_seq_length:
      input_ids.append(0)
      input_mask.append(0)
      segment_ids.append(0)

    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length

    masked_lm_positions = list(instance.masked_lm_positions)
    masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels)
    masked_lm_weights = [1.0] * len(masked_lm_ids)

    while len(masked_lm_positions) < max_predictions_per_seq:
      masked_lm_positions.append(0)
      masked_lm_ids.append(0)
      masked_lm_weights.append(0.0)

    next_sentence_label = 1 if instance.is_random_next else 0

    features = collections.OrderedDict()
    if use_v2_feature_names:
      features["input_word_ids"] = create_int_feature(input_ids)
      features["input_type_ids"] = create_int_feature(segment_ids)
    else:
      features["input_ids"] = create_int_feature(input_ids)
      features["segment_ids"] = create_int_feature(segment_ids)

    features["input_mask"] = create_int_feature(input_mask)
    features["masked_lm_positions"] = create_int_feature(masked_lm_positions)
    features["masked_lm_ids"] = create_int_feature(masked_lm_ids)
    features["masked_lm_weights"] = create_float_feature(masked_lm_weights)
    features["next_sentence_labels"] = create_int_feature([next_sentence_label])

    tf_example = tf.train.Example(features=tf.train.Features(feature=features))

    writers[writer_index].write(tf_example.SerializeToString())
    writer_index = (writer_index + 1) % len(writers)

    total_written += 1

    if inst_index < 20:
      logging.info("*** Example ***")
      logging.info("tokens: %s", " ".join(
          [tokenization.printable_text(x) for x in instance.tokens]))

      for feature_name in features.keys():
        feature = features[feature_name]
        values = []
        if feature.int64_list.value:
          values = feature.int64_list.value
        elif feature.float_list.value:
          values = feature.float_list.value
        logging.info("%s: %s", feature_name, " ".join([str(x) for x in values]))

  for writer in writers:
    writer.close()

  logging.info("Wrote %d total instances", total_written)


def create_int_feature(values):
  feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
  return feature


def create_float_feature(values):
  feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
  return feature


def create_training_instances(
    input_files,
    tokenizer,
    processor_text_fn,
    max_seq_length,
    dupe_factor,
    short_seq_prob,
    masked_lm_prob,
    max_predictions_per_seq,
    rng,
    do_whole_word_mask=False,
    max_ngram_size=None,
):
  """Create `TrainingInstance`s from raw text."""
  all_documents = [[]]

  # Input file format:
  # (1) One sentence per line. These should ideally be actual sentences, not
  # entire paragraphs or arbitrary spans of text. (Because we use the
  # sentence boundaries for the "next sentence prediction" task).
  # (2) Blank lines between documents. Document boundaries are needed so
  # that the "next sentence prediction" task doesn't span between documents.
  for input_file in input_files:
    with tf.io.gfile.GFile(input_file, "rb") as reader:
      for line in reader:
        line = processor_text_fn(line)

        # Empty lines are used as document delimiters
        if not line:
          all_documents.append([])
        tokens = tokenizer.tokenize(line)
        if tokens:
          all_documents[-1].append(tokens)

  # Remove empty documents
  all_documents = [x for x in all_documents if x]
  rng.shuffle(all_documents)

  vocab_words = list(tokenizer.vocab.keys())
  instances = []
  for _ in range(dupe_factor):
    for document_index in range(len(all_documents)):
      instances.extend(
          create_instances_from_document(
              all_documents, document_index, max_seq_length, short_seq_prob,
              masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
              do_whole_word_mask, max_ngram_size))

  rng.shuffle(instances)
  return instances


def create_instances_from_document(
    all_documents, document_index, max_seq_length, short_seq_prob,
    masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
    do_whole_word_mask=False,
    max_ngram_size=None):
  """Creates `TrainingInstance`s for a single document."""
  document = all_documents[document_index]

  # Account for [CLS], [SEP], [SEP]
  max_num_tokens = max_seq_length - 3

  # We *usually* want to fill up the entire sequence since we are padding
  # to `max_seq_length` anyways, so short sequences are generally wasted
  # computation. However, we *sometimes*
  # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
  # sequences to minimize the mismatch between pre-training and fine-tuning.
  # The `target_seq_length` is just a rough target however, whereas
  # `max_seq_length` is a hard limit.
  target_seq_length = max_num_tokens
  if rng.random() < short_seq_prob:
    target_seq_length = rng.randint(2, max_num_tokens)

  # We DON'T just concatenate all of the tokens from a document into a long
  # sequence and choose an arbitrary split point because this would make the
  # next sentence prediction task too easy. Instead, we split the input into
  # segments "A" and "B" based on the actual "sentences" provided by the user
  # input.
  instances = []
  current_chunk = []
  current_length = 0
  i = 0
  while i < len(document):
    segment = document[i]
    current_chunk.append(segment)
    current_length += len(segment)
    if i == len(document) - 1 or current_length >= target_seq_length:
      if current_chunk:
        # `a_end` is how many segments from `current_chunk` go into the `A`
        # (first) sentence.
        a_end = 1
        if len(current_chunk) >= 2:
          a_end = rng.randint(1, len(current_chunk) - 1)

        tokens_a = []
        for j in range(a_end):
          tokens_a.extend(current_chunk[j])

        tokens_b = []
        # Random next
        is_random_next = False
        if len(current_chunk) == 1 or rng.random() < 0.5:
          is_random_next = True
          target_b_length = target_seq_length - len(tokens_a)

          # This should rarely go for more than one iteration for large
          # corpora. However, just to be careful, we try to make sure that
          # the random document is not the same as the document
          # we're processing.
          for _ in range(10):
            random_document_index = rng.randint(0, len(all_documents) - 1)
            if random_document_index != document_index:
              break

          random_document = all_documents[random_document_index]
          random_start = rng.randint(0, len(random_document) - 1)
          for j in range(random_start, len(random_document)):
            tokens_b.extend(random_document[j])
            if len(tokens_b) >= target_b_length:
              break
          # We didn't actually use these segments so we "put them back" so
          # they don't go to waste.
          num_unused_segments = len(current_chunk) - a_end
          i -= num_unused_segments
        # Actual next
        else:
          is_random_next = False
          for j in range(a_end, len(current_chunk)):
            tokens_b.extend(current_chunk[j])
        truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng)

        assert len(tokens_a) >= 1
        assert len(tokens_b) >= 1

        tokens = []
        segment_ids = []
        tokens.append("[CLS]")
        segment_ids.append(0)
        for token in tokens_a:
          tokens.append(token)
          segment_ids.append(0)

        tokens.append("[SEP]")
        segment_ids.append(0)

        for token in tokens_b:
          tokens.append(token)
          segment_ids.append(1)
        tokens.append("[SEP]")
        segment_ids.append(1)

        (tokens, masked_lm_positions,
         masked_lm_labels) = create_masked_lm_predictions(
             tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
             do_whole_word_mask, max_ngram_size)
        instance = TrainingInstance(
            tokens=tokens,
            segment_ids=segment_ids,
            is_random_next=is_random_next,
            masked_lm_positions=masked_lm_positions,
            masked_lm_labels=masked_lm_labels)
        instances.append(instance)
      current_chunk = []
      current_length = 0
    i += 1

  return instances


MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
                                          ["index", "label"])

# A _Gram is a [half-open) interval of token indices which form a word.
# E.g.,
#   words:  ["The", "doghouse"]
#   tokens: ["The", "dog", "##house"]
#   grams:  [(0,1), (1,3)]
_Gram = collections.namedtuple("_Gram", ["begin", "end"])


def _window(iterable, size):
  """Helper to create a sliding window iterator with a given size.

  E.g.,
    input = [1, 2, 3, 4]
    _window(input, 1) => [1], [2], [3], [4]
    _window(input, 2) => [1, 2], [2, 3], [3, 4]
    _window(input, 3) => [1, 2, 3], [2, 3, 4]
    _window(input, 4) => [1, 2, 3, 4]
    _window(input, 5) => None

  Args:
    iterable: elements to iterate over.
    size: size of the window.

  Yields:
    Elements of `iterable` batched into a sliding window of length `size`.
  """
  i = iter(iterable)
  window = []
  try:
    for e in range(0, size):
      window.append(next(i))
    yield window
  except StopIteration:
    # handle the case where iterable's length is less than the window size.
    return
  for e in i:
    window = window[1:] + [e]
    yield window


def _contiguous(sorted_grams):
  """Test whether a sequence of grams is contiguous.

  Args:
    sorted_grams: _Grams which are sorted in increasing order.
  Returns:
    True if `sorted_grams` are touching each other.

  E.g.,
    _contiguous([(1, 4), (4, 5), (5, 10)]) == True
    _contiguous([(1, 2), (4, 5)]) == False
  """
  for a, b in _window(sorted_grams, 2):
    if a.end != b.begin:
      return False
  return True


def _masking_ngrams(grams, max_ngram_size, max_masked_tokens, rng):
  """Create a list of masking {1, ..., n}-grams from a list of one-grams.

  This is an extension of 'whole word masking' to mask multiple, contiguous
  words such as (e.g., "the red boat").

  Each input gram represents the token indices of a single word,
     words:  ["the", "red", "boat"]
     tokens: ["the", "red", "boa", "##t"]
     grams:  [(0,1), (1,2), (2,4)]

  For a `max_ngram_size` of three, possible output masks include:
    1-grams: (0,1), (1,2), (2,4)
    2-grams: (0,2), (1,4)
    3-grams: (0,4)

  Output masks will not overlap and contain less than `max_masked_tokens` total
  tokens. E.g., for the example above with `max_masked_tokens` as three,
  valid outputs are,
       [(0,1), (1,2)]  # "the", "red" covering two tokens
       [(1,2), (2,4)]  # "red", "boa", "##t" covering three tokens

  The length of the selected n-gram follows a zipf weighting to
  favor shorter n-gram sizes (weight(1)=1, weight(2)=1/2, weight(3)=1/3, ...).

  Args:
    grams: List of one-grams.
    max_ngram_size: Maximum number of contiguous one-grams combined to create
      an n-gram.
    max_masked_tokens: Maximum total number of tokens to be masked.
    rng: `random.Random` generator.

  Returns:
    A list of n-grams to be used as masks.
  """
  if not grams:
    return None

  grams = sorted(grams)
  num_tokens = grams[-1].end

  # Ensure our grams are valid (i.e., they don't overlap).
  for a, b in _window(grams, 2):
    if a.end > b.begin:
      raise ValueError("overlapping grams: {}".format(grams))

  # Build map from n-gram length to list of n-grams.
  ngrams = {i: [] for i in range(1, max_ngram_size+1)}
  for gram_size in range(1, max_ngram_size+1):
    for g in _window(grams, gram_size):
      if _contiguous(g):
        # Add an n-gram which spans these one-grams.
        ngrams[gram_size].append(_Gram(g[0].begin, g[-1].end))

  # Shuffle each list of n-grams.
  for v in ngrams.values():
    rng.shuffle(v)

  # Create the weighting for n-gram length selection.
  # Stored cumulatively for `random.choices` below.
  cummulative_weights = list(
      itertools.accumulate([1./n for n in range(1, max_ngram_size+1)]))

  output_ngrams = []
  # Keep a bitmask of which tokens have been masked.
  masked_tokens = [False] * num_tokens
  # Loop until we have enough masked tokens or there are no more candidate
  # n-grams of any length.
  # Each code path should ensure one or more elements from `ngrams` are removed
  # to guarantee this loop terminates.
  while (sum(masked_tokens) < max_masked_tokens and
         sum(len(s) for s in ngrams.values())):
    # Pick an n-gram size based on our weights.
    sz = random.choices(range(1, max_ngram_size+1),
                        cum_weights=cummulative_weights)[0]

    # Ensure this size doesn't result in too many masked tokens.
    # E.g., a two-gram contains _at least_ two tokens.
    if sum(masked_tokens) + sz > max_masked_tokens:
      # All n-grams of this length are too long and can be removed from
      # consideration.
      ngrams[sz].clear()
      continue

    # All of the n-grams of this size have been used.
    if not ngrams[sz]:
      continue

    # Choose a random n-gram of the given size.
    gram = ngrams[sz].pop()
    num_gram_tokens = gram.end-gram.begin

    # Check if this would add too many tokens.
    if num_gram_tokens + sum(masked_tokens) > max_masked_tokens:
      continue

    # Check if any of the tokens in this gram have already been masked.
    if sum(masked_tokens[gram.begin:gram.end]):
      continue

    # Found a usable n-gram! Mark its tokens as masked and add it to return.
    masked_tokens[gram.begin:gram.end] = [True] * (gram.end-gram.begin)
    output_ngrams.append(gram)
  return output_ngrams


def _tokens_to_grams(tokens):
  """Reconstitute grams (words) from `tokens`.

  E.g.,
     tokens: ['[CLS]', 'That', 'lit', '##tle', 'blue', 'tru', '##ck', '[SEP]']
     grams:  [ [1,2), [2, 4), [4,5) , [5, 6)]

  Args:
    tokens: list of tokens (word pieces or sentence pieces).

  Returns:
    List of _Grams representing spans of whole words
    (without "[CLS]" and "[SEP]").
  """
  grams = []
  gram_start_pos = None
  for i, token in enumerate(tokens):
    if gram_start_pos is not None and token.startswith("##"):
      continue
    if gram_start_pos is not None:
      grams.append(_Gram(gram_start_pos, i))
    if token not in ["[CLS]", "[SEP]"]:
      gram_start_pos = i
    else:
      gram_start_pos = None
  if gram_start_pos is not None:
    grams.append(_Gram(gram_start_pos, len(tokens)))
  return grams


def create_masked_lm_predictions(tokens, masked_lm_prob,
                                 max_predictions_per_seq, vocab_words, rng,
                                 do_whole_word_mask,
                                 max_ngram_size=None):
  """Creates the predictions for the masked LM objective."""
  if do_whole_word_mask:
    grams = _tokens_to_grams(tokens)
  else:
    # Here we consider each token to be a word to allow for sub-word masking.
    if max_ngram_size:
      raise ValueError("cannot use ngram masking without whole word masking")
    grams = [_Gram(i, i+1) for i in range(0, len(tokens))
             if tokens[i] not in ["[CLS]", "[SEP]"]]

  num_to_predict = min(max_predictions_per_seq,
                       max(1, int(round(len(tokens) * masked_lm_prob))))
  # Generate masks. If `max_ngram_size` in [0, None] it means we're doing
  # whole word masking or token level masking. Both of these can be treated
  # as the `max_ngram_size=1` case.
  masked_grams = _masking_ngrams(grams, max_ngram_size or 1,
                                 num_to_predict, rng)
  masked_lms = []
  output_tokens = list(tokens)
  for gram in masked_grams:
    # 80% of the time, replace all n-gram tokens with [MASK]
    if rng.random() < 0.8:
      replacement_action = lambda idx: "[MASK]"
    else:
      # 10% of the time, keep all the original n-gram tokens.
      if rng.random() < 0.5:
        replacement_action = lambda idx: tokens[idx]
      # 10% of the time, replace each n-gram token with a random word.
      else:
        replacement_action = lambda idx: rng.choice(vocab_words)

    for idx in range(gram.begin, gram.end):
      output_tokens[idx] = replacement_action(idx)
      masked_lms.append(MaskedLmInstance(index=idx, label=tokens[idx]))

  assert len(masked_lms) <= num_to_predict
  masked_lms = sorted(masked_lms, key=lambda x: x.index)

  masked_lm_positions = []
  masked_lm_labels = []
  for p in masked_lms:
    masked_lm_positions.append(p.index)
    masked_lm_labels.append(p.label)

  return (output_tokens, masked_lm_positions, masked_lm_labels)


def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng):
  """Truncates a pair of sequences to a maximum sequence length."""
  while True:
    total_length = len(tokens_a) + len(tokens_b)
    if total_length <= max_num_tokens:
      break

    trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
    assert len(trunc_tokens) >= 1

    # We want to sometimes truncate from the front and sometimes from the
    # back to add more randomness and avoid biases.
    if rng.random() < 0.5:
      del trunc_tokens[0]
    else:
      trunc_tokens.pop()


def get_processor_text_fn(is_sentence_piece, do_lower_case):
  def processor_text_fn(text):
    text = tokenization.convert_to_unicode(text)
    if is_sentence_piece:
      # Additional preprocessing specific to the SentencePiece tokenizer.
      text = tokenization.preprocess_text(text, lower=do_lower_case)

    return text.strip()

  return processor_text_fn


def main(_):
  if FLAGS.tokenization == "WordPiece":
    tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case
    )
    processor_text_fn = get_processor_text_fn(False, FLAGS.do_lower_case)
  else:
    assert FLAGS.tokenization == "SentencePiece"
    tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
    processor_text_fn = get_processor_text_fn(True, FLAGS.do_lower_case)

  input_files = []
  for input_pattern in FLAGS.input_file.split(","):
    input_files.extend(tf.io.gfile.glob(input_pattern))

  logging.info("*** Reading from input files ***")
  for input_file in input_files:
    logging.info("  %s", input_file)

  rng = random.Random(FLAGS.random_seed)
  instances = create_training_instances(
      input_files,
      tokenizer,
      processor_text_fn,
      FLAGS.max_seq_length,
      FLAGS.dupe_factor,
      FLAGS.short_seq_prob,
      FLAGS.masked_lm_prob,
      FLAGS.max_predictions_per_seq,
      rng,
      FLAGS.do_whole_word_mask,
      FLAGS.max_ngram_size,
  )

  output_files = FLAGS.output_file.split(",")
  logging.info("*** Writing to output files ***")
  for output_file in output_files:
    logging.info("  %s", output_file)

  write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length,
                                  FLAGS.max_predictions_per_seq, output_files,
                                  FLAGS.gzip_compress,
                                  FLAGS.use_v2_feature_names)


if __name__ == "__main__":
  flags.mark_flag_as_required("input_file")
  flags.mark_flag_as_required("output_file")
  app.run(main)
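
For reference now that the script is removed: TFRecord files it produced via write_instance_to_example_files() can still be inspected with standard TensorFlow APIs. The snippet below is an illustrative sketch, not part of this repository; it assumes the default (non --use_v2_feature_names) feature names, the default --max_seq_length=128 and --max_predictions_per_seq=20, and a placeholder file name "pretrain.tfrecord".

import tensorflow as tf

MAX_SEQ_LENGTH = 128          # matches the deleted script's --max_seq_length default
MAX_PREDICTIONS_PER_SEQ = 20  # matches the deleted script's --max_predictions_per_seq default

# Feature spec mirroring the features written by write_instance_to_example_files.
feature_spec = {
    "input_ids": tf.io.FixedLenFeature([MAX_SEQ_LENGTH], tf.int64),
    "input_mask": tf.io.FixedLenFeature([MAX_SEQ_LENGTH], tf.int64),
    "segment_ids": tf.io.FixedLenFeature([MAX_SEQ_LENGTH], tf.int64),
    "masked_lm_positions": tf.io.FixedLenFeature([MAX_PREDICTIONS_PER_SEQ], tf.int64),
    "masked_lm_ids": tf.io.FixedLenFeature([MAX_PREDICTIONS_PER_SEQ], tf.int64),
    "masked_lm_weights": tf.io.FixedLenFeature([MAX_PREDICTIONS_PER_SEQ], tf.float32),
    "next_sentence_labels": tf.io.FixedLenFeature([1], tf.int64),
}

# Pass compression_type="GZIP" if the records were written with --gzip_compress.
dataset = tf.data.TFRecordDataset("pretrain.tfrecord")  # placeholder path
for raw_record in dataset.take(1):
    example = tf.io.parse_single_example(raw_record, feature_spec)
    print({name: tensor.numpy() for name, tensor in example.items()})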