"""Script to pre-process classification data into tfrecords.""" |
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import collections |
|
import csv |
|
import os |
|
|
|
from absl import app |
|
from absl import flags |
|
from absl import logging |
|
import numpy as np |
|
import tensorflow as tf |
|
|
|
import sentencepiece as spm |
|
from official.nlp.xlnet import classifier_utils |
|
from official.nlp.xlnet import preprocess_utils |
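
# Example invocation (a sketch; the file name and paths are illustrative):
#   python preprocess_classification_data.py \
#     --task_name=imdb \
#     --data_dir=/path/to/aclImdb \
#     --spiece_model_file=/path/to/spiece.model \
#     --output_dir=/tmp/imdb_tfrecords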

flags.DEFINE_bool(
    "overwrite_data",
    default=False,
    help="If False, will use cached data if available.")
flags.DEFINE_string("output_dir", default="", help="Output dir for TF records.")
flags.DEFINE_string(
    "spiece_model_file", default="", help="Sentence Piece model path.")
flags.DEFINE_string("data_dir", default="", help="Directory for input data.")

flags.DEFINE_string(
    "eval_split", default="dev", help="Could be 'dev' or 'test'.")
flags.DEFINE_string("task_name", default=None, help="Task name.")
flags.DEFINE_integer(
    "eval_batch_size", default=64, help="Batch size for evaluation.")
flags.DEFINE_integer("max_seq_length", default=128, help="Max sequence length.")
flags.DEFINE_integer(
    "num_passes",
    default=1,
    help="Num passes for processing training data. "
    "This is used to batch data without loss for TPUs.")
flags.DEFINE_bool("uncased", default=False, help="Use uncased inputs.")
flags.DEFINE_bool(
    "is_regression", default=False, help="Whether it's a regression task.")
flags.DEFINE_bool(
    "use_bert_format",
    default=False,
    help="Whether to use BERT format to arrange input data.")

FLAGS = flags.FLAGS
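
# --task_name has no usable default; marking it required yields a clear error
# at startup instead of an AttributeError on `None.lower()` in main().
flags.mark_flag_as_required("task_name")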


class InputExample(object):
  """A single training/test example for simple sequence classification."""

  def __init__(self, guid, text_a, text_b=None, label=None):
    """Constructs an InputExample.

    Args:
      guid: Unique id for the example.
      text_a: string. The untokenized text of the first sequence. For single
        sequence tasks, only this sequence must be specified.
      text_b: (Optional) string. The untokenized text of the second sequence.
        Must be specified only for sequence pair tasks.
      label: (Optional) string. The label of the example. This should be
        specified for train and dev examples, but not for test examples.
    """
    self.guid = guid
    self.text_a = text_a
    self.text_b = text_b
    self.label = label


class DataProcessor(object):
  """Base class for data converters for sequence classification data sets."""

  def get_train_examples(self, data_dir):
    """Gets a collection of `InputExample`s for the train set."""
    raise NotImplementedError()

  def get_dev_examples(self, data_dir):
    """Gets a collection of `InputExample`s for the dev set."""
    raise NotImplementedError()

  def get_test_examples(self, data_dir):
    """Gets a collection of `InputExample`s for prediction."""
    raise NotImplementedError()

  def get_labels(self):
    """Gets the list of labels for this data set."""
    raise NotImplementedError()

  @classmethod
  def _read_tsv(cls, input_file, quotechar=None):
    """Reads a tab separated value file."""
    with tf.io.gfile.GFile(input_file, "r") as f:
      reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
      lines = []
      for line in reader:
        # Skip empty lines.
        if not line:
          continue
        lines.append(line)
      return lines


class GLUEProcessor(DataProcessor):
  """Processor for GLUE-style TSV data sets."""

  def __init__(self):
    self.train_file = "train.tsv"
    self.dev_file = "dev.tsv"
    self.test_file = "test.tsv"
    self.label_column = None
    self.text_a_column = None
    self.text_b_column = None
    self.contains_header = True
    self.test_text_a_column = None
    self.test_text_b_column = None
    self.test_contains_header = True

  def get_train_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, self.train_file)), "train")

  def get_dev_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, self.dev_file)), "dev")

  def get_test_examples(self, data_dir):
    """See base class."""
    if self.test_text_a_column is None:
      self.test_text_a_column = self.text_a_column
    if self.test_text_b_column is None:
      self.test_text_b_column = self.text_b_column

    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, self.test_file)), "test")

  def get_labels(self):
    """See base class."""
    return ["0", "1"]

  def _create_examples(self, lines, set_type):
    """Creates examples for the training, dev, and test sets."""
    examples = []
    for (i, line) in enumerate(lines):
      if i == 0 and self.contains_header and set_type != "test":
        continue
      if i == 0 and self.test_contains_header and set_type == "test":
        continue
      guid = "%s-%s" % (set_type, i)

      a_column = (
          self.text_a_column if set_type != "test" else self.test_text_a_column)
      b_column = (
          self.text_b_column if set_type != "test" else self.test_text_b_column)

      if len(line) <= a_column:
        logging.warning("Incomplete line, ignored.")
        continue
      text_a = line[a_column]

      if b_column is not None:
        if len(line) <= b_column:
          logging.warning("Incomplete line, ignored.")
          continue
        text_b = line[b_column]
      else:
        text_b = None

      if set_type == "test":
        # Test sets are unlabeled; use a dummy label so the feature
        # conversion code has something to write.
        label = self.get_labels()[0]
      else:
        if len(line) <= self.label_column:
          logging.warning("Incomplete line, ignored.")
          continue
        label = line[self.label_column]
      examples.append(
          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
    return examples
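

# A minimal sketch (not used by this script) of how a new single-sentence,
# GLUE-style task could plug in; the column indices below are illustrative.
class _ExampleSingleSentenceProcessor(GLUEProcessor):
  """Hypothetical processor: label in column 1, text in column 3."""

  def __init__(self):
    super(_ExampleSingleSentenceProcessor, self).__init__()
    self.label_column = 1
    self.text_a_column = 3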


class Yelp5Processor(DataProcessor):
  """Processor for the Yelp-5 data set."""

  def get_train_examples(self, data_dir):
    return self._create_examples(os.path.join(data_dir, "train.csv"))

  def get_dev_examples(self, data_dir):
    return self._create_examples(os.path.join(data_dir, "test.csv"))

  def get_labels(self):
    """See base class."""
    return ["1", "2", "3", "4", "5"]

  def _create_examples(self, input_file):
    """Creates examples for the training and dev sets."""
    examples = []
    with tf.io.gfile.GFile(input_file) as f:
      reader = csv.reader(f)
      for i, line in enumerate(reader):
        # Each row is (label, text); un-escape the doubled and backslashed
        # quotes in the raw text.
        label = line[0]
        text_a = line[1].replace('""', '"').replace('\\"', '"')
        examples.append(
            InputExample(guid=str(i), text_a=text_a, text_b=None, label=label))
    return examples


class ImdbProcessor(DataProcessor):
  """Processor for the IMDB movie review data set."""

  def get_labels(self):
    return ["neg", "pos"]

  def get_train_examples(self, data_dir):
    return self._create_examples(os.path.join(data_dir, "train"))

  def get_dev_examples(self, data_dir):
    return self._create_examples(os.path.join(data_dir, "test"))

  def _create_examples(self, data_dir):
    """Creates examples from one-review-per-file `neg`/`pos` directories."""
    examples = []
    for label in ["neg", "pos"]:
      cur_dir = os.path.join(data_dir, label)
      for filename in tf.io.gfile.listdir(cur_dir):
        if not filename.endswith("txt"):
          continue

        if len(examples) % 1000 == 0:
          logging.info("Loading example %d", len(examples))

        path = os.path.join(cur_dir, filename)
        with tf.io.gfile.GFile(path) as f:
          text = f.read().strip().replace("<br />", " ")
        examples.append(
            InputExample(
                guid="unused_id", text_a=text, text_b=None, label=label))
    return examples


class MnliMatchedProcessor(GLUEProcessor):
  """Processor for the MNLI-matched data set."""

  def __init__(self):
    super(MnliMatchedProcessor, self).__init__()
    self.dev_file = "dev_matched.tsv"
    self.test_file = "test_matched.tsv"
    self.label_column = -1
    self.text_a_column = 8
    self.text_b_column = 9

  def get_labels(self):
    return ["contradiction", "entailment", "neutral"]


class MnliMismatchedProcessor(MnliMatchedProcessor):
  """Processor for the MNLI-mismatched data set."""

  def __init__(self):
    super(MnliMismatchedProcessor, self).__init__()
    self.dev_file = "dev_mismatched.tsv"
    self.test_file = "test_mismatched.tsv"


class StsbProcessor(GLUEProcessor):
  """Processor for the STS-B regression data set."""

  def __init__(self):
    super(StsbProcessor, self).__init__()
    self.label_column = 9
    self.text_a_column = 7
    self.text_b_column = 8

  def get_labels(self):
    return [0.0]

  def _create_examples(self, lines, set_type):
    """Creates examples; mirrors the base class but parses labels as floats."""
    examples = []
    for (i, line) in enumerate(lines):
      if i == 0 and self.contains_header and set_type != "test":
        continue
      if i == 0 and self.test_contains_header and set_type == "test":
        continue
      guid = "%s-%s" % (set_type, i)

      a_column = (
          self.text_a_column if set_type != "test" else self.test_text_a_column)
      b_column = (
          self.text_b_column if set_type != "test" else self.test_text_b_column)

      if len(line) <= a_column:
        logging.warning("Incomplete line, ignored.")
        continue
      text_a = line[a_column]

      if b_column is not None:
        if len(line) <= b_column:
          logging.warning("Incomplete line, ignored.")
          continue
        text_b = line[b_column]
      else:
        text_b = None

      if set_type == "test":
        label = self.get_labels()[0]
      else:
        if len(line) <= self.label_column:
          logging.warning("Incomplete line, ignored.")
          continue
        label = float(line[self.label_column])
      examples.append(
          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))

    return examples


def file_based_convert_examples_to_features(examples,
                                            label_list,
                                            max_seq_length,
                                            tokenize_fn,
                                            output_file,
                                            num_passes=1):
  """Converts a set of `InputExample`s to a TFRecord file."""

  if tf.io.gfile.exists(output_file) and not FLAGS.overwrite_data:
    logging.info("Tfrecord %s already exists, skipping.", output_file)
    return

  logging.info("Create new tfrecord %s.", output_file)

  writer = tf.io.TFRecordWriter(output_file)

  # Replicate the examples so that one epoch over the tfrecord covers
  # `num_passes` passes over the data.
  examples *= num_passes

  for (ex_index, example) in enumerate(examples):
    if ex_index % 10000 == 0:
      logging.info("Writing example %d of %d", ex_index, len(examples))

    feature = classifier_utils.convert_single_example(ex_index, example,
                                                      label_list,
                                                      max_seq_length,
                                                      tokenize_fn,
                                                      FLAGS.use_bert_format)

    def create_int_feature(values):
      f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
      return f

    def create_float_feature(values):
      f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
      return f

    features = collections.OrderedDict()
    features["input_ids"] = create_int_feature(feature.input_ids)
    features["input_mask"] = create_float_feature(feature.input_mask)
    features["segment_ids"] = create_int_feature(feature.segment_ids)
    if label_list is not None:
      features["label_ids"] = create_int_feature([feature.label_id])
    else:
      # Regression tasks store the label as a float.
      features["label_ids"] = create_float_feature([float(feature.label_id)])
    features["is_real_example"] = create_int_feature(
        [int(feature.is_real_example)])

    tf_example = tf.train.Example(features=tf.train.Features(feature=features))
    writer.write(tf_example.SerializeToString())
  writer.close()
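

# A minimal sketch (not used by this script) of a parsing spec matching the
# records written above; `seq_length` must equal --max_seq_length at write
# time, and `is_regression` mirrors --is_regression.
def _example_feature_spec(seq_length, is_regression=False):
  return {
      "input_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
      "input_mask": tf.io.FixedLenFeature([seq_length], tf.float32),
      "segment_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
      "label_ids": tf.io.FixedLenFeature(
          [], tf.float32 if is_regression else tf.int64),
      "is_real_example": tf.io.FixedLenFeature([], tf.int64),
  }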


def main(_):
  logging.set_verbosity(logging.INFO)
  processors = {
      "mnli_matched": MnliMatchedProcessor,
      "mnli_mismatched": MnliMismatchedProcessor,
      "sts-b": StsbProcessor,
      "imdb": ImdbProcessor,
      "yelp5": Yelp5Processor
  }

  task_name = FLAGS.task_name.lower()

  if task_name not in processors:
    raise ValueError("Task not found: %s" % task_name)

  processor = processors[task_name]()
  label_list = processor.get_labels() if not FLAGS.is_regression else None

  sp = spm.SentencePieceProcessor()
  sp.Load(FLAGS.spiece_model_file)

  def tokenize_fn(text):
    text = preprocess_utils.preprocess_text(text, lower=FLAGS.uncased)
    return preprocess_utils.encode_ids(sp, text)

  spm_basename = os.path.basename(FLAGS.spiece_model_file)

  train_file_base = "{}.len-{}.train.tf_record".format(spm_basename,
                                                       FLAGS.max_seq_length)
  train_file = os.path.join(FLAGS.output_dir, train_file_base)
  logging.info("Use tfrecord file %s", train_file)

  train_examples = processor.get_train_examples(FLAGS.data_dir)
  np.random.shuffle(train_examples)
  logging.info("Num of train samples: %d", len(train_examples))

  file_based_convert_examples_to_features(train_examples, label_list,
                                          FLAGS.max_seq_length, tokenize_fn,
                                          train_file, FLAGS.num_passes)
  if FLAGS.eval_split == "dev":
    eval_examples = processor.get_dev_examples(FLAGS.data_dir)
  else:
    eval_examples = processor.get_test_examples(FLAGS.data_dir)

  logging.info("Num of eval samples: %d", len(eval_examples))
  while len(eval_examples) % FLAGS.eval_batch_size != 0:
    eval_examples.append(classifier_utils.PaddingInputExample())

  eval_file_base = "{}.len-{}.{}.eval.tf_record".format(spm_basename,
                                                        FLAGS.max_seq_length,
                                                        FLAGS.eval_split)
  eval_file = os.path.join(FLAGS.output_dir, eval_file_base)

  file_based_convert_examples_to_features(eval_examples, label_list,
                                          FLAGS.max_seq_length, tokenize_fn,
                                          eval_file)


if __name__ == "__main__":
  app.run(main)