Pradeep Kumar committed
Commit 20e1eb7 · verified · 1 parent: 01a6c3b

Delete create_finetuning_data.py

Files changed (1)
  1. create_finetuning_data.py +0 -441
create_finetuning_data.py DELETED
@@ -1,441 +0,0 @@
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""BERT finetuning task dataset generator."""

import functools
import json
import os

# Import libraries
from absl import app
from absl import flags
import tensorflow as tf, tf_keras
from official.nlp.data import classifier_data_lib
from official.nlp.data import sentence_retrieval_lib
# word-piece tokenizer based squad_lib
from official.nlp.data import squad_lib as squad_lib_wp
# sentence-piece tokenizer based squad_lib
from official.nlp.data import squad_lib_sp
from official.nlp.data import tagging_data_lib
from official.nlp.tools import tokenization

FLAGS = flags.FLAGS

flags.DEFINE_enum(
    "fine_tuning_task_type", "classification",
    ["classification", "regression", "squad", "retrieval", "tagging"],
    "The name of the BERT fine tuning task for which data "
    "will be generated.")

# BERT classification specific flags.
flags.DEFINE_string(
    "input_data_dir", None,
    "The input data dir. Should contain the .tsv files (or other data files) "
    "for the task.")

flags.DEFINE_enum(
    "classification_task_name", "MNLI", [
        "AX", "COLA", "IMDB", "MNLI", "MRPC", "PAWS-X", "QNLI", "QQP", "RTE",
        "SST-2", "STS-B", "WNLI", "XNLI", "XTREME-XNLI", "XTREME-PAWS-X",
        "AX-g", "SUPERGLUE-RTE", "CB", "BoolQ", "WIC"
    ], "The name of the task to train BERT classifier. The "
    "difference between XTREME-XNLI and XNLI is: 1. the format "
    "of input tsv files; 2. the dev set for XTREME is english "
    "only and for XNLI is all languages combined. Same for "
    "PAWS-X.")

# MNLI task-specific flag.
flags.DEFINE_enum("mnli_type", "matched", ["matched", "mismatched"],
                  "The type of MNLI dataset.")

# XNLI task-specific flag.
flags.DEFINE_string(
    "xnli_language", "en",
    "Language of training data for XNLI task. If the value is 'all', the data "
    "of all languages will be used for training.")

# PAWS-X task-specific flag.
flags.DEFINE_string(
    "pawsx_language", "en",
    "Language of training data for PAWS-X task. If the value is 'all', the data "
    "of all languages will be used for training.")

# XTREME classification specific flags. Only used in XtremePawsx and XtremeXnli.
flags.DEFINE_string(
    "translated_input_data_dir", None,
    "The translated input data dir. Should contain the .tsv files (or other "
    "data files) for the task.")

# Retrieval task-specific flags.
flags.DEFINE_enum("retrieval_task_name", "bucc", ["bucc", "tatoeba"],
                  "The name of sentence retrieval task for scoring")

# Tagging task-specific flags.
flags.DEFINE_enum("tagging_task_name", "panx", ["panx", "udpos"],
                  "The name of BERT tagging (token classification) task.")

flags.DEFINE_bool("tagging_only_use_en_train", True,
                  "Whether only use english training data in tagging.")

# BERT Squad task-specific flags.
flags.DEFINE_string(
    "squad_data_file", None,
    "The input data file for generating training data for BERT squad task.")

flags.DEFINE_string(
    "translated_squad_data_folder", None,
    "The translated data folder for generating training data for BERT squad "
    "task.")

flags.DEFINE_integer(
    "doc_stride", 128,
    "When splitting up a long document into chunks, how much stride to "
    "take between chunks.")

flags.DEFINE_integer(
    "max_query_length", 64,
    "The maximum number of tokens for the question. Questions longer than "
    "this will be truncated to this length.")

flags.DEFINE_bool(
    "version_2_with_negative", False,
    "If true, the SQuAD examples contain some that do not have an answer.")

flags.DEFINE_bool(
    "xlnet_format", False,
    "If true, then data will be preprocessed in a paragraph, query, class order"
    " instead of the BERT-style class, paragraph, query order.")

# XTREME specific flags.
flags.DEFINE_bool("only_use_en_dev", True, "Whether only use english dev data.")

# Shared flags across BERT fine-tuning tasks.
flags.DEFINE_string("vocab_file", None,
                    "The vocabulary file that the BERT model was trained on.")

flags.DEFINE_string(
    "train_data_output_path", None,
    "The path in which generated training input data will be written as tf"
    " records.")

flags.DEFINE_string(
    "eval_data_output_path", None,
    "The path in which generated evaluation input data will be written as tf"
    " records.")

flags.DEFINE_string(
    "test_data_output_path", None,
    "The path in which generated test input data will be written as tf"
    " records. If None, do not generate test data. Must be a pattern template"
    " as test_{}.tfrecords if processor has language specific test data.")

flags.DEFINE_string("meta_data_file_path", None,
                    "The path in which input meta data will be written.")

flags.DEFINE_bool(
    "do_lower_case", True,
    "Whether to lower case the input text. Should be True for uncased "
    "models and False for cased models.")

flags.DEFINE_integer(
    "max_seq_length", 128,
    "The maximum total input sequence length after WordPiece tokenization. "
    "Sequences longer than this will be truncated, and sequences shorter "
    "than this will be padded.")

flags.DEFINE_string("sp_model_file", "",
                    "The path to the model used by sentence piece tokenizer.")

flags.DEFINE_enum(
    "tokenization", "WordPiece", ["WordPiece", "SentencePiece"],
    "Specifies the tokenizer implementation, i.e., whether to use WordPiece "
    "or SentencePiece tokenizer. Canonical BERT uses WordPiece tokenizer, "
    "while ALBERT uses SentencePiece tokenizer.")

flags.DEFINE_string(
    "tfds_params", "", "Comma-separated list of TFDS parameter assignments for "
    "generic classification data import (for more details "
    "see the TfdsProcessor class documentation).")


def generate_classifier_dataset():
  """Generates classifier dataset and returns input meta data."""
  if FLAGS.classification_task_name in [
      "COLA",
      "WNLI",
      "SST-2",
      "MRPC",
      "QQP",
      "STS-B",
      "MNLI",
      "QNLI",
      "RTE",
      "AX",
      "SUPERGLUE-RTE",
      "CB",
      "BoolQ",
      "WIC",
  ]:
    assert not FLAGS.input_data_dir or FLAGS.tfds_params
  else:
    assert (FLAGS.input_data_dir and FLAGS.classification_task_name or
            FLAGS.tfds_params)

  if FLAGS.tokenization == "WordPiece":
    tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
    processor_text_fn = tokenization.convert_to_unicode
  else:
    assert FLAGS.tokenization == "SentencePiece"
    tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
    processor_text_fn = functools.partial(
        tokenization.preprocess_text, lower=FLAGS.do_lower_case)

  if FLAGS.tfds_params:
    processor = classifier_data_lib.TfdsProcessor(
        tfds_params=FLAGS.tfds_params, process_text_fn=processor_text_fn)
    return classifier_data_lib.generate_tf_record_from_data_file(
        processor,
        None,
        tokenizer,
        train_data_output_path=FLAGS.train_data_output_path,
        eval_data_output_path=FLAGS.eval_data_output_path,
        test_data_output_path=FLAGS.test_data_output_path,
        max_seq_length=FLAGS.max_seq_length)
  else:
    processors = {
        "ax":
            classifier_data_lib.AxProcessor,
        "cola":
            classifier_data_lib.ColaProcessor,
        "imdb":
            classifier_data_lib.ImdbProcessor,
        "mnli":
            functools.partial(
                classifier_data_lib.MnliProcessor, mnli_type=FLAGS.mnli_type),
        "mrpc":
            classifier_data_lib.MrpcProcessor,
        "qnli":
            classifier_data_lib.QnliProcessor,
        "qqp":
            classifier_data_lib.QqpProcessor,
        "rte":
            classifier_data_lib.RteProcessor,
        "sst-2":
            classifier_data_lib.SstProcessor,
        "sts-b":
            classifier_data_lib.StsBProcessor,
        "xnli":
            functools.partial(
                classifier_data_lib.XnliProcessor,
                language=FLAGS.xnli_language),
        "paws-x":
            functools.partial(
                classifier_data_lib.PawsxProcessor,
                language=FLAGS.pawsx_language),
        "wnli":
            classifier_data_lib.WnliProcessor,
        "xtreme-xnli":
            functools.partial(
                classifier_data_lib.XtremeXnliProcessor,
                translated_data_dir=FLAGS.translated_input_data_dir,
                only_use_en_dev=FLAGS.only_use_en_dev),
        "xtreme-paws-x":
            functools.partial(
                classifier_data_lib.XtremePawsxProcessor,
                translated_data_dir=FLAGS.translated_input_data_dir,
                only_use_en_dev=FLAGS.only_use_en_dev),
        "ax-g":
            classifier_data_lib.AXgProcessor,
        "superglue-rte":
            classifier_data_lib.SuperGLUERTEProcessor,
        "cb":
            classifier_data_lib.CBProcessor,
        "boolq":
            classifier_data_lib.BoolQProcessor,
        "wic":
            classifier_data_lib.WnliProcessor,
    }
    task_name = FLAGS.classification_task_name.lower()
    if task_name not in processors:
      raise ValueError("Task not found: %s" % (task_name,))

    processor = processors[task_name](process_text_fn=processor_text_fn)
    return classifier_data_lib.generate_tf_record_from_data_file(
        processor,
        FLAGS.input_data_dir,
        tokenizer,
        train_data_output_path=FLAGS.train_data_output_path,
        eval_data_output_path=FLAGS.eval_data_output_path,
        test_data_output_path=FLAGS.test_data_output_path,
        max_seq_length=FLAGS.max_seq_length)


def generate_regression_dataset():
  """Generates regression dataset and returns input meta data."""
  if FLAGS.tokenization == "WordPiece":
    tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
    processor_text_fn = tokenization.convert_to_unicode
  else:
    assert FLAGS.tokenization == "SentencePiece"
    tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
    processor_text_fn = functools.partial(
        tokenization.preprocess_text, lower=FLAGS.do_lower_case)

  if FLAGS.tfds_params:
    processor = classifier_data_lib.TfdsProcessor(
        tfds_params=FLAGS.tfds_params, process_text_fn=processor_text_fn)
    return classifier_data_lib.generate_tf_record_from_data_file(
        processor,
        None,
        tokenizer,
        train_data_output_path=FLAGS.train_data_output_path,
        eval_data_output_path=FLAGS.eval_data_output_path,
        test_data_output_path=FLAGS.test_data_output_path,
        max_seq_length=FLAGS.max_seq_length)
  else:
    raise ValueError("No data processor found for the given regression task.")


def generate_squad_dataset():
  """Generates squad training dataset and returns input meta data."""
  assert FLAGS.squad_data_file
  if FLAGS.tokenization == "WordPiece":
    return squad_lib_wp.generate_tf_record_from_json_file(
        input_file_path=FLAGS.squad_data_file,
        vocab_file_path=FLAGS.vocab_file,
        output_path=FLAGS.train_data_output_path,
        translated_input_folder=FLAGS.translated_squad_data_folder,
        max_seq_length=FLAGS.max_seq_length,
        do_lower_case=FLAGS.do_lower_case,
        max_query_length=FLAGS.max_query_length,
        doc_stride=FLAGS.doc_stride,
        version_2_with_negative=FLAGS.version_2_with_negative,
        xlnet_format=FLAGS.xlnet_format)
  else:
    assert FLAGS.tokenization == "SentencePiece"
    return squad_lib_sp.generate_tf_record_from_json_file(
        input_file_path=FLAGS.squad_data_file,
        sp_model_file=FLAGS.sp_model_file,
        output_path=FLAGS.train_data_output_path,
        translated_input_folder=FLAGS.translated_squad_data_folder,
        max_seq_length=FLAGS.max_seq_length,
        do_lower_case=FLAGS.do_lower_case,
        max_query_length=FLAGS.max_query_length,
        doc_stride=FLAGS.doc_stride,
        xlnet_format=FLAGS.xlnet_format,
        version_2_with_negative=FLAGS.version_2_with_negative)


def generate_retrieval_dataset():
  """Generate retrieval test and dev dataset and returns input meta data."""
  assert (FLAGS.input_data_dir and FLAGS.retrieval_task_name)
  if FLAGS.tokenization == "WordPiece":
    tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
    processor_text_fn = tokenization.convert_to_unicode
  else:
    assert FLAGS.tokenization == "SentencePiece"
    tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
    processor_text_fn = functools.partial(
        tokenization.preprocess_text, lower=FLAGS.do_lower_case)

  processors = {
      "bucc": sentence_retrieval_lib.BuccProcessor,
      "tatoeba": sentence_retrieval_lib.TatoebaProcessor,
  }

  task_name = FLAGS.retrieval_task_name.lower()
  if task_name not in processors:
    raise ValueError("Task not found: %s" % task_name)

  processor = processors[task_name](process_text_fn=processor_text_fn)

  return sentence_retrieval_lib.generate_sentence_retrevial_tf_record(
      processor, FLAGS.input_data_dir, tokenizer, FLAGS.eval_data_output_path,
      FLAGS.test_data_output_path, FLAGS.max_seq_length)


def generate_tagging_dataset():
  """Generates tagging dataset."""
  processors = {
      "panx":
          functools.partial(
              tagging_data_lib.PanxProcessor,
              only_use_en_train=FLAGS.tagging_only_use_en_train,
              only_use_en_dev=FLAGS.only_use_en_dev),
      "udpos":
          functools.partial(
              tagging_data_lib.UdposProcessor,
              only_use_en_train=FLAGS.tagging_only_use_en_train,
              only_use_en_dev=FLAGS.only_use_en_dev),
  }
  task_name = FLAGS.tagging_task_name.lower()
  if task_name not in processors:
    raise ValueError("Task not found: %s" % task_name)

  if FLAGS.tokenization == "WordPiece":
    tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
    processor_text_fn = tokenization.convert_to_unicode
  elif FLAGS.tokenization == "SentencePiece":
    tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
    processor_text_fn = functools.partial(
        tokenization.preprocess_text, lower=FLAGS.do_lower_case)
  else:
    raise ValueError("Unsupported tokenization: %s" % FLAGS.tokenization)

  processor = processors[task_name]()
  return tagging_data_lib.generate_tf_record_from_data_file(
      processor, FLAGS.input_data_dir, tokenizer, FLAGS.max_seq_length,
      FLAGS.train_data_output_path, FLAGS.eval_data_output_path,
      FLAGS.test_data_output_path, processor_text_fn)


def main(_):
  if FLAGS.tokenization == "WordPiece":
    if not FLAGS.vocab_file:
      raise ValueError(
          "FLAG vocab_file for word-piece tokenizer is not specified.")
  else:
    assert FLAGS.tokenization == "SentencePiece"
    if not FLAGS.sp_model_file:
      raise ValueError(
          "FLAG sp_model_file for sentence-piece tokenizer is not specified.")

  if FLAGS.fine_tuning_task_type != "retrieval":
    flags.mark_flag_as_required("train_data_output_path")

  if FLAGS.fine_tuning_task_type == "classification":
    input_meta_data = generate_classifier_dataset()
  elif FLAGS.fine_tuning_task_type == "regression":
    input_meta_data = generate_regression_dataset()
  elif FLAGS.fine_tuning_task_type == "retrieval":
    input_meta_data = generate_retrieval_dataset()
  elif FLAGS.fine_tuning_task_type == "squad":
    input_meta_data = generate_squad_dataset()
  else:
    assert FLAGS.fine_tuning_task_type == "tagging"
    input_meta_data = generate_tagging_dataset()

  tf.io.gfile.makedirs(os.path.dirname(FLAGS.meta_data_file_path))
  with tf.io.gfile.GFile(FLAGS.meta_data_file_path, "w") as writer:
    writer.write(json.dumps(input_meta_data, indent=4) + "\n")


if __name__ == "__main__":
  flags.mark_flag_as_required("meta_data_file_path")
  app.run(main)
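
For reference, a minimal sketch of how the deleted script was typically driven, here shown generating MNLI classification TFRecords with a WordPiece vocabulary. The flag names come from the file above; the file paths are placeholders, not values from this commit, and the script is assumed to be run from a checkout of the TF Models repo with its dependencies installed.

# Hypothetical invocation sketch (paths are placeholders).
import subprocess

subprocess.run(
    [
        "python", "create_finetuning_data.py",
        "--fine_tuning_task_type=classification",
        "--classification_task_name=MNLI",
        "--mnli_type=matched",
        "--tokenization=WordPiece",
        "--vocab_file=/path/to/vocab.txt",            # placeholder
        "--input_data_dir=/path/to/MNLI",             # placeholder
        "--train_data_output_path=/tmp/mnli_train.tf_record",
        "--eval_data_output_path=/tmp/mnli_eval.tf_record",
        "--meta_data_file_path=/tmp/mnli_meta_data",
        "--max_seq_length=128",
    ],
    check=True,
)

On success, the script writes the train/eval TFRecords to the given output paths and a JSON metadata file (label count, sequence length, etc.) to --meta_data_file_path.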