# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Script to pre-process classification data into tfrecords."""

import collections
import csv
import os

# Import libraries
from absl import app
from absl import flags
from absl import logging
import numpy as np
import tensorflow as tf, tf_keras
import sentencepiece as spm
from official.legacy.xlnet import classifier_utils
from official.legacy.xlnet import preprocess_utils

flags.DEFINE_bool(
    "overwrite_data",
    default=False,
    help="If False, will use cached data if available.")
flags.DEFINE_string("output_dir", default="",
                    help="Output dir for TF records.")
flags.DEFINE_string(
    "spiece_model_file", default="", help="SentencePiece model path.")
flags.DEFINE_string("data_dir", default="", help="Directory for input data.")

# task specific
flags.DEFINE_string("eval_split", default="dev",
                    help="Could be 'dev' or 'test'.")
flags.DEFINE_string("task_name", default=None, help="Task name.")
flags.DEFINE_integer(
    "eval_batch_size", default=64, help="Batch size for evaluation.")
flags.DEFINE_integer("max_seq_length", default=128,
                     help="Max sequence length.")
flags.DEFINE_integer(
    "num_passes",
    default=1,
    help="Number of passes for processing training data. "
    "This is used to batch data without loss for TPUs.")
flags.DEFINE_bool("uncased", default=False, help="Use uncased inputs.")
flags.DEFINE_bool(
    "is_regression", default=False, help="Whether it's a regression task.")
flags.DEFINE_bool(
    "use_bert_format",
    default=False,
    help="Whether to use BERT format to arrange input data.")

FLAGS = flags.FLAGS
""" self.guid = guid self.text_a = text_a self.text_b = text_b self.label = label class DataProcessor(object): """Base class for data converters for sequence classification data sets.""" def get_train_examples(self, data_dir): """Gets a collection of `InputExample`s for the train set.""" raise NotImplementedError() def get_dev_examples(self, data_dir): """Gets a collection of `InputExample`s for the dev set.""" raise NotImplementedError() def get_test_examples(self, data_dir): """Gets a collection of `InputExample`s for prediction.""" raise NotImplementedError() def get_labels(self): """Gets the list of labels for this data set.""" raise NotImplementedError() @classmethod def _read_tsv(cls, input_file, quotechar=None): """Reads a tab separated value file.""" with tf.io.gfile.GFile(input_file, "r") as f: reader = csv.reader(f, delimiter="\t", quotechar=quotechar) lines = [] for line in reader: # pylint: disable=g-explicit-length-test if len(line) == 0: continue lines.append(line) return lines class GLUEProcessor(DataProcessor): """GLUEProcessor.""" def __init__(self): self.train_file = "train.tsv" self.dev_file = "dev.tsv" self.test_file = "test.tsv" self.label_column = None self.text_a_column = None self.text_b_column = None self.contains_header = True self.test_text_a_column = None self.test_text_b_column = None self.test_contains_header = True def get_train_examples(self, data_dir): """See base class.""" return self._create_examples( self._read_tsv(os.path.join(data_dir, self.train_file)), "train") def get_dev_examples(self, data_dir): """See base class.""" return self._create_examples( self._read_tsv(os.path.join(data_dir, self.dev_file)), "dev") def get_test_examples(self, data_dir): """See base class.""" if self.test_text_a_column is None: self.test_text_a_column = self.text_a_column if self.test_text_b_column is None: self.test_text_b_column = self.text_b_column return self._create_examples( self._read_tsv(os.path.join(data_dir, self.test_file)), "test") def get_labels(self): """See base class.""" return ["0", "1"] def _create_examples(self, lines, set_type): """Creates examples for the training and dev sets.""" examples = [] for (i, line) in enumerate(lines): if i == 0 and self.contains_header and set_type != "test": continue if i == 0 and self.test_contains_header and set_type == "test": continue guid = "%s-%s" % (set_type, i) a_column = ( self.text_a_column if set_type != "test" else self.test_text_a_column) b_column = ( self.text_b_column if set_type != "test" else self.test_text_b_column) # there are some incomplete lines in QNLI if len(line) <= a_column: logging.warning("Incomplete line, ignored.") continue text_a = line[a_column] if b_column is not None: if len(line) <= b_column: logging.warning("Incomplete line, ignored.") continue text_b = line[b_column] else: text_b = None if set_type == "test": label = self.get_labels()[0] else: if len(line) <= self.label_column: logging.warning("Incomplete line, ignored.") continue label = line[self.label_column] examples.append( InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) return examples class Yelp5Processor(DataProcessor): """Yelp5Processor.""" def get_train_examples(self, data_dir): return self._create_examples(os.path.join(data_dir, "train.csv")) def get_dev_examples(self, data_dir): return self._create_examples(os.path.join(data_dir, "test.csv")) def get_labels(self): """See base class.""" return ["1", "2", "3", "4", "5"] def _create_examples(self, input_file): """Creates examples for the training and dev 
sets.""" examples = [] with tf.io.gfile.GFile(input_file) as f: reader = csv.reader(f) for i, line in enumerate(reader): label = line[0] text_a = line[1].replace('""', '"').replace('\\"', '"') examples.append( InputExample(guid=str(i), text_a=text_a, text_b=None, label=label)) return examples class ImdbProcessor(DataProcessor): """ImdbProcessor.""" def get_labels(self): return ["neg", "pos"] def get_train_examples(self, data_dir): return self._create_examples(os.path.join(data_dir, "train")) def get_dev_examples(self, data_dir): return self._create_examples(os.path.join(data_dir, "test")) def _create_examples(self, data_dir): """Creates examples.""" examples = [] for label in ["neg", "pos"]: cur_dir = os.path.join(data_dir, label) for filename in tf.io.gfile.listdir(cur_dir): if not filename.endswith("txt"): continue if len(examples) % 1000 == 0: logging.info("Loading dev example %d", len(examples)) path = os.path.join(cur_dir, filename) with tf.io.gfile.GFile(path) as f: text = f.read().strip().replace("
", " ") examples.append( InputExample( guid="unused_id", text_a=text, text_b=None, label=label)) return examples class MnliMatchedProcessor(GLUEProcessor): """MnliMatchedProcessor.""" def __init__(self): super(MnliMatchedProcessor, self).__init__() self.dev_file = "dev_matched.tsv" self.test_file = "test_matched.tsv" self.label_column = -1 self.text_a_column = 8 self.text_b_column = 9 def get_labels(self): return ["contradiction", "entailment", "neutral"] class MnliMismatchedProcessor(MnliMatchedProcessor): def __init__(self): super(MnliMismatchedProcessor, self).__init__() self.dev_file = "dev_mismatched.tsv" self.test_file = "test_mismatched.tsv" class StsbProcessor(GLUEProcessor): """StsbProcessor.""" def __init__(self): super(StsbProcessor, self).__init__() self.label_column = 9 self.text_a_column = 7 self.text_b_column = 8 def get_labels(self): return [0.0] def _create_examples(self, lines, set_type): """Creates examples for the training and dev sets.""" examples = [] for (i, line) in enumerate(lines): if i == 0 and self.contains_header and set_type != "test": continue if i == 0 and self.test_contains_header and set_type == "test": continue guid = "%s-%s" % (set_type, i) a_column = ( self.text_a_column if set_type != "test" else self.test_text_a_column) b_column = ( self.text_b_column if set_type != "test" else self.test_text_b_column) # there are some incomplete lines in QNLI if len(line) <= a_column: logging.warning("Incomplete line, ignored.") continue text_a = line[a_column] if b_column is not None: if len(line) <= b_column: logging.warning("Incomplete line, ignored.") continue text_b = line[b_column] else: text_b = None if set_type == "test": label = self.get_labels()[0] else: if len(line) <= self.label_column: logging.warning("Incomplete line, ignored.") continue label = float(line[self.label_column]) examples.append( InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) return examples def file_based_convert_examples_to_features(examples, label_list, max_seq_length, tokenize_fn, output_file, num_passes=1): """Convert a set of `InputExample`s to a TFRecord file.""" # do not create duplicated records if tf.io.gfile.exists(output_file) and not FLAGS.overwrite_data: logging.info("Do not overwrite tfrecord %s exists.", output_file) return logging.info("Create new tfrecord %s.", output_file) writer = tf.io.TFRecordWriter(output_file) examples *= num_passes for (ex_index, example) in enumerate(examples): if ex_index % 10000 == 0: logging.info("Writing example %d of %d", ex_index, len(examples)) feature = classifier_utils.convert_single_example(ex_index, example, label_list, max_seq_length, tokenize_fn, FLAGS.use_bert_format) def create_int_feature(values): f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) return f def create_float_feature(values): f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values))) return f features = collections.OrderedDict() features["input_ids"] = create_int_feature(feature.input_ids) features["input_mask"] = create_float_feature(feature.input_mask) features["segment_ids"] = create_int_feature(feature.segment_ids) if label_list is not None: features["label_ids"] = create_int_feature([feature.label_id]) else: features["label_ids"] = create_float_feature([float(feature.label_id)]) features["is_real_example"] = create_int_feature( [int(feature.is_real_example)]) tf_example = tf.train.Example(features=tf.train.Features(feature=features)) writer.write(tf_example.SerializeToString()) writer.close() def main(_): 
def main(_):
  logging.set_verbosity(logging.INFO)
  processors = {
      "mnli_matched": MnliMatchedProcessor,
      "mnli_mismatched": MnliMismatchedProcessor,
      "sts-b": StsbProcessor,
      "imdb": ImdbProcessor,
      "yelp5": Yelp5Processor
  }

  task_name = FLAGS.task_name.lower()

  if task_name not in processors:
    raise ValueError("Task not found: %s" % (task_name))

  processor = processors[task_name]()
  label_list = processor.get_labels() if not FLAGS.is_regression else None

  sp = spm.SentencePieceProcessor()
  sp.Load(FLAGS.spiece_model_file)

  def tokenize_fn(text):
    text = preprocess_utils.preprocess_text(text, lower=FLAGS.uncased)
    return preprocess_utils.encode_ids(sp, text)

  spm_basename = os.path.basename(FLAGS.spiece_model_file)

  train_file_base = "{}.len-{}.train.tf_record".format(spm_basename,
                                                       FLAGS.max_seq_length)
  train_file = os.path.join(FLAGS.output_dir, train_file_base)
  logging.info("Use tfrecord file %s", train_file)

  train_examples = processor.get_train_examples(FLAGS.data_dir)
  np.random.shuffle(train_examples)
  logging.info("Num of train samples: %d", len(train_examples))

  file_based_convert_examples_to_features(train_examples, label_list,
                                          FLAGS.max_seq_length, tokenize_fn,
                                          train_file, FLAGS.num_passes)

  if FLAGS.eval_split == "dev":
    eval_examples = processor.get_dev_examples(FLAGS.data_dir)
  else:
    eval_examples = processor.get_test_examples(FLAGS.data_dir)

  logging.info("Num of eval samples: %d", len(eval_examples))

  # TPU requires a fixed batch size for all batches, therefore the number
  # of examples must be a multiple of the batch size, or else examples
  # will get dropped. So we pad with fake examples which are ignored
  # later on. These do NOT count towards the metric (all tf.metrics
  # support a per-instance weight, and these get a weight of 0.0).
  #
  # Modified in XL: We also adopt the same mechanism for GPUs.
  while len(eval_examples) % FLAGS.eval_batch_size != 0:
    eval_examples.append(classifier_utils.PaddingInputExample())

  eval_file_base = "{}.len-{}.{}.eval.tf_record".format(spm_basename,
                                                        FLAGS.max_seq_length,
                                                        FLAGS.eval_split)
  eval_file = os.path.join(FLAGS.output_dir, eval_file_base)

  file_based_convert_examples_to_features(eval_examples, label_list,
                                          FLAGS.max_seq_length, tokenize_fn,
                                          eval_file)


if __name__ == "__main__":
  app.run(main)
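# Example invocation (illustrative only; the script name and paths are
# placeholders, and the data_dir assumes an `aclImdb`-style layout with
# train/{pos,neg} and test/{pos,neg} subfolders of .txt files, as expected by
# `ImdbProcessor`):
#
#   python preprocess_classification_data.py \
#     --task_name=imdb \
#     --data_dir=/path/to/aclImdb \
#     --spiece_model_file=/path/to/spiece.model \
#     --output_dir=/tmp/imdb_tfrecords \
#     --max_seq_length=512 \
#     --eval_split=test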