Pradeep Kumar committed on
Commit 439fdba · verified · 1 Parent(s): be3c9fb

Delete create_xlnet_pretraining_data.py

Files changed (1): create_xlnet_pretraining_data.py +0 -721
create_xlnet_pretraining_data.py DELETED
@@ -1,721 +0,0 @@
- # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- """Create LM TF examples for XLNet."""
-
- import dataclasses
- import json
- import math
- import os
-
- import random
- from typing import Iterable, Mapping, List, Optional, Tuple
- import unicodedata
-
- # Import libraries
-
- from absl import app
- from absl import flags
- from absl import logging
-
- import numpy as np
- import tensorflow as tf, tf_keras
-
- from official.nlp.tools import tokenization
-
- special_symbols = {
-     "<unk>": 0,
-     "<s>": 1,
-     "</s>": 2,
-     "<cls>": 3,
-     "<sep>": 4,
-     "<pad>": 5,
-     "<mask>": 6,
-     "<eod>": 7,
-     "<eop>": 8,
- }
-
- FLAGS = flags.FLAGS
-
- flags.DEFINE_integer("seq_length", 512,
-                      help="Sequence length.")
- flags.DEFINE_integer("reuse_length", 256,
-                      help="Number of tokens that can be reused as memory. "
-                      "Could be half of `seq_length`.")
- flags.DEFINE_string("input_file", None,
-                     "Input raw text file (or comma-separated list of files).")
- flags.DEFINE_string(
-     "save_dir", None,
-     "Directory for saving processed data.")
- flags.DEFINE_string("sp_model_file", "",
-                     "The path to the model used by sentence piece tokenizer.")
- flags.DEFINE_bool("use_eod_token", True,
-                   "Whether or not to include EOD tokens.")
- flags.DEFINE_bool("bi_data", True, "Whether or not to use bi-directional data.")
- flags.DEFINE_bool(
-     "do_lower_case", True,
-     "Whether to lower case the input text. Should be True for uncased "
-     "models and False for cased models.")
- flags.DEFINE_integer("per_host_batch_size", 32, "Batch size per host.")
- flags.DEFINE_integer("num_cores_per_host", 16,
-                      "The number of (TPU) cores per host.")
- flags.DEFINE_string("prefix", "", "Filename prefix.")
- flags.DEFINE_string("suffix", "", "Filename suffix.")
-
- flags.DEFINE_integer("task_id", None,
-                      "The id of the current task.")
- flags.DEFINE_integer("num_tasks", None,
-                      "The total number of tasks.")
- flags.DEFINE_integer("num_passes", 1, "The number of times to run the script.")
-
-
- @dataclasses.dataclass
- class TrainingInstance:
-   """Representation of a single XLNet Pretraining instance."""
-   data: Iterable[int]
-   segment_ids: Iterable[int]
-   boundary_indices: Iterable[int]
-   label: int
-
-   def to_feature(self) -> Mapping[str, tf.train.Feature]:
-     feat = lambda x: tf.train.Feature(int64_list=tf.train.Int64List(value=x))
-     return dict(
-         input_word_ids=feat(self.data),
-         input_type_ids=feat(self.segment_ids),
-         boundary_indices=feat(self.boundary_indices),
-         label=feat([self.label]))
-
-   def to_example(self) -> tf.train.Example:
-     return tf.train.Example(
-         features=tf.train.Features(feature=self.to_feature()))
-
-   def __str__(self):
-     def seq_to_str(seq):
-       return " ".join([str(x) for x in seq])
-
-     s = ""
-     s += "tokens: %s\n" % seq_to_str(self.data)
-     s += "segment_ids: %s\n" % seq_to_str(self.segment_ids)
-     s += "boundary_indices: %s\n" % seq_to_str(self.boundary_indices)
-     s += "label: %s\n" % self.label
-     s += "\n"
-     return s
-
-   def __repr__(self):
-     return self.__str__()
-
-
- def _preprocess_line(line: str, do_lower_case: bool = False) -> str:
-   """Preprocesses an individual raw text line.
-
-   This function will:
-     - Remove extraneous spaces.
-     - Replace `` with ", and '' with ".
-     - Replace accents.
-     - Apply lower casing.
-
-   Args:
-     line: The input line to preprocess.
-     do_lower_case: Whether or not to lower case the text.
-
-   Returns:
-     The preprocessed line.
-
-   """
-   line = " ".join(line.split())
-   line = line.replace("``", "\"").replace("''", "\"")
-
-   # Replace accents.
-   line = unicodedata.normalize("NFKD", line)
-   line = "".join([c for c in line if not unicodedata.combining(c)])
-
-   if do_lower_case:
-     line = line.lower()
-   return line
-
-
- def preprocess_and_tokenize_input_files(
-     input_files: Iterable[str],
-     tokenizer: tokenization.FullSentencePieceTokenizer,
-     use_eod: bool = True,
-     do_lower_case: bool = False,
-     log_example_freq: int = 100000) -> List[Tuple[np.array, np.array]]:
-   """Preprocesses and encodes raw text from input files.
-
-   This function preprocesses raw text and encodes it into tokens using a
-   `SentencePieceModel` tokenization method. This also provides the sentence
-   indicator for each token.
-
-   Args:
-     input_files: The list of input file names.
-     tokenizer: The SentencePiece tokenizer that has the attribute `sp_model`.
-     use_eod: Whether or not to use an EOD indicator. If `False`, then EOD is
-       not included.
-     do_lower_case: Whether or not to apply lower casing during raw text
-       preprocessing.
-     log_example_freq: The optional field for how many lines to process before
-       emitting an info log.
-
-   Returns:
-     The preprocessed list. Each entry in the list is a tuple consisting of
-     the token IDs and the sentence IDs.
-
-   """
-   all_data = []
-   eod_symbol = special_symbols["<eod>"]
-
-   total_number_of_lines = 0
-
-   # Input file format:
-   # (1) One sentence per line. These should ideally be actual sentences, not
-   # entire paragraphs or arbitrary spans of text. (Because we use the
-   # sentence boundaries for the "next sentence prediction" task).
-   # (2) Blank lines between documents. Document boundaries are needed so
-   # that the "next sentence prediction" task doesn't span between documents.
-   for input_file in input_files:
-     line_count = 0
-     logging.info("Preprocessing %s", input_file)
-
-     all_tokens = []
-     all_sentence_ids = []
-
-     sentence_id = True
-
-     with tf.io.gfile.GFile(input_file, "rb") as reader:
-       while True:
-         line = tokenization.convert_to_unicode(reader.readline())
-         if not line:
-           break
-
-         line_count += 1
-         if line_count % log_example_freq == 0:
-           logging.info("Loading line %d", line_count)
-
-         line = line.strip()
-
-         if not line:
-           if use_eod:
-             token_ids = [eod_symbol]
-             sentence_id = not sentence_id
-           else:
-             continue
-         else:
-           preprocessed_line = _preprocess_line(
-               line=line, do_lower_case=do_lower_case)
-           token_ids = tokenization.encode_ids(
-               sp_model=tokenizer.sp_model, text=preprocessed_line)
-
-         all_tokens.extend(token_ids)
-         all_sentence_ids.extend([sentence_id] * len(token_ids))
-         sentence_id = not sentence_id
-     logging.info("Finished processing %s. Number of lines: %d",
-                  input_file, line_count)
-     if line_count == 0:
-       continue
-     total_number_of_lines += line_count
-     all_tokens = np.array(all_tokens, dtype=np.int64)
-     all_sentence_ids = np.array(all_sentence_ids, dtype=bool)
-     all_data.append((all_tokens, all_sentence_ids))
-
-   logging.info("Completed text preprocessing. Total number of lines: %d",
-                total_number_of_lines)
-   return all_data
-
-
- def _reshape_to_batch_dimensions(
-     tokens: np.array,
-     sentence_ids: np.array,
-     per_host_batch_size: int) -> Tuple[np.array, np.array]:
-   """Truncates and reshapes input data with a batch major dimension.
-
-   Args:
-     tokens: The input token ids. This should have the same shape as
-       `sentence_ids`.
-     sentence_ids: The input sentence ids. This should have the same shape as
-       `tokens`.
-     per_host_batch_size: The target per-host batch size.
-
-   Returns:
-     The tuple of reshaped tokens and sentence_ids.
-   """
-   num_steps = len(tokens) // per_host_batch_size
-   truncated_data_length = num_steps * per_host_batch_size
-
-   logging.info("per_host_batch_size: %d", per_host_batch_size)
-   logging.info("num_steps: %d", num_steps)
-   def truncate_and_reshape(a):
-     return a[:truncated_data_length].reshape((per_host_batch_size, num_steps))
-
-   return (truncate_and_reshape(tokens), truncate_and_reshape(sentence_ids))
-
-
- def _create_a_and_b_segments(
-     tokens: np.array,
-     sentence_ids: np.array,
-     begin_index: int,
-     total_length: int,
-     no_cut_probability: float = 0.5):
-   """Splits segments A and B from a single instance of tokens and sentence ids.
-
-   Args:
-     tokens: The 1D input token ids. This represents an individual entry within
-       a batch.
-     sentence_ids: The 1D input sentence ids. This represents an individual
-       entry within a batch. This should be the same length as `tokens`.
-     begin_index: The reference beginning index to split data.
-     total_length: The target combined length of segments A and B.
-     no_cut_probability: The probability of not cutting a segment despite
-       a cut possibly existing.
-
-   Returns:
-     A tuple consisting of A data, B data, and label.
-
-   """
-   data_length = tokens.shape[0]
-   if begin_index + total_length >= data_length:
-     logging.info("[_create_segments]: begin_index %d + total_length %d >= "
-                  "data_length %d", begin_index, total_length, data_length)
-     return None
-
-   end_index = begin_index + 1
-   cut_indices = []
-
-   # Identify all indices where sentence IDs change from one to the next.
-   while end_index < data_length:
-     if sentence_ids[end_index] != sentence_ids[end_index - 1]:
-       if end_index - begin_index >= total_length:
-         break
-       cut_indices.append(end_index)
-     end_index += 1
-
-   a_begin = begin_index
-
-   if not cut_indices or random.random() < no_cut_probability:
-     # Segments A and B are contained within the same sentence.
-     label = 0
-     if not cut_indices:
-       a_end = end_index
-     else:
-       a_end = random.choice(cut_indices)
-     b_length = max(1, total_length - (a_end - a_begin))
-     b_begin = random.randint(0, data_length - 1 - b_length)
-     b_end = b_begin + b_length
-
-     while b_begin > 0 and sentence_ids[b_begin - 1] == sentence_ids[b_begin]:
-       b_begin -= 1
-     while (b_end < data_length - 1 and
-            sentence_ids[b_end - 1] == sentence_ids[b_end]):
-       b_end += 1
-   else:
-     # Segments A and B are different sentences.
-     label = 1
-     a_end = random.choice(cut_indices)
-     b_begin = a_end
-     b_end = end_index
-
-   while a_end - a_begin + b_end - b_begin > total_length:
-     if a_end - a_begin > b_end - b_begin:
-       # Delete only the right side for the LM objective.
-       a_end -= 1
-     else:
-       b_end -= 1
-   if a_end >= data_length or b_end >= data_length:
-     logging.info("[_create_segments]: a_end %d or b_end %d >= data_length %d",
-                  a_end, b_end, data_length)
-     return None
-
-   a_data = tokens[a_begin: a_end]
-   b_data = tokens[b_begin: b_end]
-   return a_data, b_data, label
-
-
- def _is_functional_piece(piece: str) -> bool:
-   return piece != "<unk>" and piece.startswith("<") and piece.endswith(">")
-
-
- def _is_start_piece(piece: str) -> bool:
-   special_pieces = set(list('!"#$%&\"()*+,-./:;?@[\\]^_`{|}~'))
-   if (piece.startswith("▁") or piece in special_pieces):
-     return True
-   else:
-     return False
-
-
- def _get_boundary_indices(
-     data: np.array,
-     tokenizer: tokenization.FullSentencePieceTokenizer) -> np.array:
-   """Gets the boundary indices of whole words."""
-   seq_length = len(data)
-   boundary_indices = []
-   for index, piece in enumerate(tokenizer.convert_ids_to_tokens(data.tolist())):
-     if _is_start_piece(piece) and not _is_functional_piece(piece):
-       boundary_indices.append(index)
-   boundary_indices.append(seq_length)
-   return boundary_indices
-
-
- def _convert_tokens_to_instances(
-     tokens: np.array,
-     sentence_ids: np.array,
-     per_host_batch_size: int,
-     seq_length: int,
-     reuse_length: int,
-     bi_data: bool,
-     tokenizer: tokenization.FullSentencePieceTokenizer,
-     num_cores_per_host: int = 0,
-     logging_frequency: int = 500) -> List[TrainingInstance]:
-   """Converts tokens and sentence IDs into individual training instances.
-
-   The format of data in the XLNet pretraining task is very similar to the
-   BERT pretraining task. Two segments A and B are randomly sampled, and the
-   concatenation of A and B into a single sequence is used to perform
-   language modeling.
-
-   To create an XLNet Pretraining instance from a single long sequence, S:
-   - Create a segment of length `reuse_length`. This first segment represents
-     past tokens. During modeling, this segment is used to cache obtained
-     content representations for the segment recurrence mechanism.
-   - Similar to BERT, create a segment of length `seq_length` - `reuse_length`
-     composed of A and B segments.
-     For XLNet, the order is "A", "SEP", "B", "SEP", "CLS".
-
-   Args:
-     tokens: All tokens concatenated into a single list.
-     sentence_ids: All sentence IDs concatenated into a single list.
-     per_host_batch_size: The target batch size per host.
-     seq_length: The max sequence length.
-     reuse_length: The number of tokens to use from the previous segment.
-     bi_data: Whether or not to use bidirectional data.
-     tokenizer: The SentencePiece tokenizer that has the attribute `sp_model`.
-     num_cores_per_host: The number of cores per host. This is required if
-       `bi_data` = `True`.
-     logging_frequency: The frequency at which to log status updates.
-
-   Returns:
-     A list of `TrainingInstance` objects.
-   """
-   instances = []
-
-   per_core_batch_size = (per_host_batch_size // num_cores_per_host
-                          if bi_data else None)
-
-   if bi_data:
-     logging.info("Bi-directional data enabled.")
-     assert per_host_batch_size % (2 * num_cores_per_host) == 0
-     forward_tokens, forward_sentence_ids = _reshape_to_batch_dimensions(
-         tokens=tokens,
-         sentence_ids=sentence_ids,
-         per_host_batch_size=per_host_batch_size // 2)
-     forward_data_shape = (num_cores_per_host, 1, per_core_batch_size // 2, -1)
-
-     forward_tokens = forward_tokens.reshape(forward_data_shape)
-     forward_sentence_ids = forward_sentence_ids.reshape(forward_data_shape)
-
-     backwards_tokens = forward_tokens[:, :, :, ::-1]
-     backwards_sentence_ids = forward_sentence_ids[:, :, :, ::-1]
-
-     tokens = np.concatenate([forward_tokens, backwards_tokens], 1).reshape(
-         per_host_batch_size, -1)
-     sentence_ids = np.concatenate(
-         [forward_sentence_ids, backwards_sentence_ids], 1).reshape(
-             per_host_batch_size, -1)
-   else:
-     logging.info("Bi-directional data disabled.")
-     tokens, sentence_ids = _reshape_to_batch_dimensions(
-         tokens=tokens,
-         sentence_ids=sentence_ids,
-         per_host_batch_size=per_host_batch_size)
-
-   logging.info("Tokens shape: %s", tokens.shape)
-
-   data_length = tokens.shape[1]
-   sep = np.array([special_symbols["<sep>"]], dtype=np.int64)
-   cls = np.array([special_symbols["<cls>"]], dtype=np.int64)
-   # 2 sep, 1 cls
-   num_special_tokens = 3
-
-   data_index = 0
-   batch_number = 0
-   step_size = reuse_length if reuse_length else seq_length
-   num_batches = math.ceil(data_length / step_size)
-
-   while data_index + seq_length <= data_length:
-     if batch_number % logging_frequency == 0:
-       logging.info("Processing batch %d of %d", batch_number, num_batches)
-
-     for batch_index in range(per_host_batch_size):
-       previous_segment_tokens = tokens[
-           batch_index, data_index: data_index + reuse_length]
-
-       results = _create_a_and_b_segments(
-           tokens=tokens[batch_index],
-           sentence_ids=sentence_ids[batch_index],
-           begin_index=data_index + reuse_length,
-           total_length=seq_length - reuse_length - num_special_tokens)
-
-       if results is None:
-         logging.info("Stopping at data index: %d", data_index)
-         break
-       a_data, b_data, label = results
-
-       data = np.concatenate(
-           [previous_segment_tokens, a_data, sep, b_data, sep, cls])
-       a_length = a_data.shape[0]
-       b_length = b_data.shape[0]
-       segment_ids = ([0] * (reuse_length + a_length) + [0]
-                      + [1] * b_length + [1] + [2])
-       boundary_indices = _get_boundary_indices(tokenizer=tokenizer,
-                                                data=data)
-       assert len(data) == seq_length
-       assert len(segment_ids) == seq_length
-       assert len(boundary_indices) > 0  # pylint: disable=g-explicit-length-test
-
-       instances.append(TrainingInstance(
-           data=data,
-           segment_ids=segment_ids,
-           boundary_indices=boundary_indices,
-           label=label))
-     batch_number += 1
-     data_index += step_size
-   return instances
-
-
- def write_instances_to_tfrecord(
-     instances: Iterable[TrainingInstance],
-     save_path: str):
-   """Writes instances to TFRecord."""
-   record_writer = tf.io.TFRecordWriter(save_path)
-   logging.info("Start writing to %s.", save_path)
-
-   for i, instance in enumerate(instances):
-     if i < 5:
-       logging.info("Instance %d: %s", i, str(instance))
-     record_writer.write(instance.to_example().SerializeToString())
-
-   record_writer.close()
-   logging.info("Done writing %s.", save_path)
-
-
- def shuffle_and_combine_preprocessed_data(
-     all_data: List[Tuple[np.array, np.array]]) -> Tuple[np.array, np.array]:
-   """Shuffles and combines preprocessed token/sentence IDs from documents."""
-   document_permutation = np.random.permutation(len(all_data))
-
-   previous_sentence_id = None
-
-   all_tokens, all_sentence_ids = [], []
-   for document_index in document_permutation:
-     tokens, sentence_ids = all_data[document_index]
-     # pylint: disable=g-explicit-length-test
-     if len(tokens) == 0:
-       continue
-     if (previous_sentence_id is not None and
-         sentence_ids[0] == previous_sentence_id):
-       sentence_ids = np.logical_not(sentence_ids)
-
-     all_tokens.append(tokens)
-     all_sentence_ids.append(sentence_ids)
-
-     previous_sentence_id = sentence_ids[-1]
-
-   return np.concatenate(all_tokens), np.concatenate(all_sentence_ids)
-
-
- def get_tfrecord_name(
-     per_host_batch_size: int,
-     num_cores_per_host: int,
-     seq_length: int,
-     bi_data: bool,
-     reuse_length: int,
-     do_lower_case: bool,
-     use_eod_token: bool,
-     prefix: str = "",
-     suffix: str = "",
-     pass_id: int = 0,
-     num_passes: int = 1,
-     task_id: int = None,
-     num_tasks: int = None) -> str:
-   """Formats the resulting TFRecord name based on provided inputs."""
-   components = []
-   if prefix:
-     components.append(prefix)
-   components.append("seqlen-{}".format(seq_length))
-   if reuse_length == 0:
-     components.append("memless")
-   else:
-     components.append("reuse-{}".format(reuse_length))
-   components.append("bs-{}".format(per_host_batch_size))
-   components.append("cores-{}".format(num_cores_per_host))
-
-   if do_lower_case:
-     components.append("uncased")
-   else:
-     components.append("cased")
-   if use_eod_token:
-     components.append("eod")
-   if bi_data:
-     components.append("bi")
-   else:
-     components.append("uni")
-
-   if suffix:
-     components.append(suffix)
-
-   s = "_".join(components) + ".tfrecord"
-   if num_passes == 1 and task_id is None:
-     return s
-
-   if task_id is None:
-     num_tasks = 1
-     task_id = 0
-
-   current_shard = task_id * num_passes + pass_id
-   total_shards = num_tasks * num_passes
-   return s + "-{}-of-{}".format(current_shard, total_shards)
-
-
- def create_tfrecords(
-     tokenizer: tokenization.FullSentencePieceTokenizer,
-     input_file_or_files: str,
-     use_eod_token: bool,
-     do_lower_case: bool,
-     per_host_batch_size: int,
-     seq_length: int,
-     reuse_length: int,
-     bi_data: bool,
-     num_cores_per_host: int,
-     save_dir: str,
-     prefix: str = "",
-     suffix: str = "",
-     num_tasks: Optional[int] = None,
-     task_id: Optional[int] = None,
-     num_passes: int = 1):
-   """Runs the end-to-end preprocessing pipeline."""
-
-   logging.info("Input configuration:")
-   logging.info("input file(s): %s", input_file_or_files)
-   logging.info("use_eod_token: %s", use_eod_token)
-   logging.info("do_lower_case: %s", do_lower_case)
-   logging.info("per_host_batch_size: %d", per_host_batch_size)
-   logging.info("seq_length: %d", seq_length)
-   logging.info("reuse_length: %d", reuse_length)
-   logging.info("bi_data: %s", bi_data)
-   logging.info("num_cores_per_host: %d", num_cores_per_host)
-   logging.info("save_dir: %s", save_dir)
-   if task_id is not None and num_tasks is not None:
-     logging.info("task_id: %d", task_id)
-     logging.info("num_tasks: %d", num_tasks)
-
-   input_files = []
-   for input_pattern in input_file_or_files.split(","):
-     input_files.extend(tf.io.gfile.glob(input_pattern))
-
-   logging.info("*** Reading from input files ***")
-   for input_file in input_files:
-     logging.info(" %s", input_file)
-
-   logging.info("Shuffling the files with a fixed random seed.")
-   np.random.shuffle(input_files)
-   if num_tasks is not None:
-     assert task_id is not None
-     logging.info("Total number of input files: %d", len(input_files))
-     logging.info("Splitting into %d shards of %d files each.",
-                  num_tasks, len(input_files) // num_tasks)
-     input_files = input_files[task_id::num_tasks]
-
-   all_data = preprocess_and_tokenize_input_files(
-       input_files=input_files,
-       tokenizer=tokenizer,
-       use_eod=use_eod_token,
-       do_lower_case=do_lower_case)
-   for pass_id in range(num_passes):
-     logging.info("Beginning pass %d of %d", pass_id, num_passes)
-     tokens, sentence_ids = shuffle_and_combine_preprocessed_data(all_data)
-
-     assert len(tokens) == len(sentence_ids)
-
-     filename = get_tfrecord_name(
-         per_host_batch_size=per_host_batch_size,
-         num_cores_per_host=num_cores_per_host,
-         seq_length=seq_length,
-         bi_data=bi_data,
-         use_eod_token=use_eod_token,
-         reuse_length=reuse_length,
-         do_lower_case=do_lower_case,
-         prefix=prefix,
-         suffix=suffix,
-         pass_id=pass_id,
-         num_passes=num_passes,
-         num_tasks=num_tasks,
-         task_id=task_id)
-     save_path = os.path.join(save_dir, filename)
-     if os.path.exists(save_path):
-       # If the path already exists, then we were probably preempted but
-       # previously wrote this file.
-       logging.info("%s already exists, skipping this batch.", save_path)
-     else:
-       instances = _convert_tokens_to_instances(
-           tokenizer=tokenizer,
-           tokens=tokens,
-           sentence_ids=sentence_ids,
-           per_host_batch_size=per_host_batch_size,
-           seq_length=seq_length,
-           reuse_length=reuse_length,
-           bi_data=bi_data,
-           num_cores_per_host=num_cores_per_host)
-       write_instances_to_tfrecord(instances=instances, save_path=save_path)
-
-   if task_id is None or task_id == 0:
-     corpus_info = {
-         "vocab_size": 32000,
-         "per_host_batch_size": per_host_batch_size,
-         "num_cores_per_host": num_cores_per_host,
-         "seq_length": seq_length,
-         "reuse_length": reuse_length,
-         "do_lower_case": do_lower_case,
-         "bi_data": bi_data,
-         "use_eod_token": use_eod_token,
-     }
-     corpus_fname = os.path.basename(filename) + ".json"
-     corpus_destination = os.path.join(save_dir, corpus_fname)
-     logging.info("Saving corpus info to %s", corpus_destination)
-
-     with tf.io.gfile.GFile(corpus_destination, "w") as fp:
-       json.dump(corpus_info, fp)
-
-
- def main(_):
-   tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
-   create_tfrecords(
-       tokenizer=tokenizer,
-       input_file_or_files=FLAGS.input_file,
-       use_eod_token=FLAGS.use_eod_token,
-       do_lower_case=FLAGS.do_lower_case,
-       per_host_batch_size=FLAGS.per_host_batch_size,
-       seq_length=FLAGS.seq_length,
-       reuse_length=FLAGS.reuse_length,
-       bi_data=FLAGS.bi_data,
-       num_cores_per_host=FLAGS.num_cores_per_host,
-       save_dir=FLAGS.save_dir,
-       prefix=FLAGS.prefix,
-       suffix=FLAGS.suffix,
-       num_tasks=FLAGS.num_tasks,
-       task_id=FLAGS.task_id,
-       num_passes=FLAGS.num_passes)
-
-
- if __name__ == "__main__":
-   np.random.seed(0)
-   logging.set_verbosity(logging.INFO)
-   app.run(main)
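
For reference, records produced by this (now deleted) script can be read back with a feature spec that mirrors `TrainingInstance.to_feature`. The sketch below is illustrative only and is not part of the repository; the output file name and the `SEQ_LENGTH` value are assumptions based on the default flags above, and `boundary_indices` is treated as variable-length because its count depends on the text.

import tensorflow as tf

SEQ_LENGTH = 512  # Assumed; must match the --seq_length used when generating the data.

# Feature names and dtypes follow TrainingInstance.to_feature in the deleted file.
feature_spec = {
    "input_word_ids": tf.io.FixedLenFeature([SEQ_LENGTH], tf.int64),
    "input_type_ids": tf.io.FixedLenFeature([SEQ_LENGTH], tf.int64),
    "boundary_indices": tf.io.VarLenFeature(tf.int64),  # Whole-word boundary positions.
    "label": tf.io.FixedLenFeature([1], tf.int64),
}

def _parse(serialized_example):
  # Decode one serialized tf.train.Example into a dict of tensors.
  return tf.io.parse_single_example(serialized_example, feature_spec)

# Hypothetical file name following the pattern built by get_tfrecord_name() with default flags.
dataset = tf.data.TFRecordDataset(
    "seqlen-512_reuse-256_bs-32_cores-16_uncased_eod_bi.tfrecord")
dataset = dataset.map(_parse)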