"""BERT library to process data for classification task.""" |
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import collections |
|
import csv |
|
import importlib |
|
import os |
|
|
|
from absl import logging |
|
import tensorflow as tf |
|
import tensorflow_datasets as tfds |
|
|
|
from official.nlp.bert import tokenization |
|
|
|
|
|
class InputExample(object): |
|
"""A single training/test example for simple sequence classification.""" |
|
|
|
def __init__(self, |
|
guid, |
|
text_a, |
|
text_b=None, |
|
label=None, |
|
weight=None, |
|
int_iden=None): |
|
"""Constructs a InputExample. |
|
|
|
Args: |
|
guid: Unique id for the example. |
|
text_a: string. The untokenized text of the first sequence. For single |
|
sequence tasks, only this sequence must be specified. |
|
      text_b: (Optional) string. The untokenized text of the second sequence.

        Only needs to be specified for sequence pair tasks.
|
label: (Optional) string. The label of the example. This should be |
|
specified for train and dev examples, but not for test examples. |
|
weight: (Optional) float. The weight of the example to be used during |
|
training. |
|
      int_iden: (Optional) int. The integer identification number of the

        example in the corpus.
|
""" |
|
self.guid = guid |
|
self.text_a = text_a |
|
self.text_b = text_b |
|
self.label = label |
|
self.weight = weight |
|
self.int_iden = int_iden |
|
|
|
|
|
class InputFeatures(object): |
|
"""A single set of features of data.""" |
|
|
|
def __init__(self, |
|
input_ids, |
|
input_mask, |
|
segment_ids, |
|
label_id, |
|
is_real_example=True, |
|
weight=None, |
|
int_iden=None): |
|
self.input_ids = input_ids |
|
self.input_mask = input_mask |
|
self.segment_ids = segment_ids |
|
self.label_id = label_id |
|
self.is_real_example = is_real_example |
|
self.weight = weight |
|
self.int_iden = int_iden |
|
|
|
|
|
class DataProcessor(object): |
|
"""Base class for data converters for sequence classification data sets.""" |
|
|
|
def __init__(self, process_text_fn=tokenization.convert_to_unicode): |
|
self.process_text_fn = process_text_fn |
|
|
|
def get_train_examples(self, data_dir): |
|
"""Gets a collection of `InputExample`s for the train set.""" |
|
raise NotImplementedError() |
|
|
|
def get_dev_examples(self, data_dir): |
|
"""Gets a collection of `InputExample`s for the dev set.""" |
|
raise NotImplementedError() |
|
|
|
def get_test_examples(self, data_dir): |
|
"""Gets a collection of `InputExample`s for prediction.""" |
|
raise NotImplementedError() |
|
|
|
def get_labels(self): |
|
"""Gets the list of labels for this data set.""" |
|
raise NotImplementedError() |
|
|
|
@staticmethod |
|
def get_processor_name(): |
|
"""Gets the string identifier of the processor.""" |
|
raise NotImplementedError() |
|
|
|
@classmethod |
|
def _read_tsv(cls, input_file, quotechar=None): |
|
"""Reads a tab separated value file.""" |
|
with tf.io.gfile.GFile(input_file, "r") as f: |
|
reader = csv.reader(f, delimiter="\t", quotechar=quotechar) |
|
lines = [] |
|
for line in reader: |
|
lines.append(line) |
|
return lines |
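
# A minimal sketch of a custom processor (hypothetical `MyBinaryProcessor`;

# assumes tab-separated files whose rows are `label<TAB>text`):

#

#   class MyBinaryProcessor(DataProcessor):

#

#     def get_train_examples(self, data_dir):

#       lines = self._read_tsv(os.path.join(data_dir, "train.tsv"))

#       return [

#           InputExample(guid="train-%d" % i, text_a=line[1], label=line[0])

#           for i, line in enumerate(lines)

#       ]

#

#     def get_labels(self):

#       return ["0", "1"]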
|
|
|
|
|
class XnliProcessor(DataProcessor): |
|
"""Processor for the XNLI data set.""" |
|
supported_languages = [ |
|
"ar", "bg", "de", "el", "en", "es", "fr", "hi", "ru", "sw", "th", "tr", |
|
"ur", "vi", "zh" |
|
] |
|
|
|
def __init__(self, |
|
language="en", |
|
process_text_fn=tokenization.convert_to_unicode): |
|
super(XnliProcessor, self).__init__(process_text_fn) |
|
if language == "all": |
|
self.languages = XnliProcessor.supported_languages |
|
elif language not in XnliProcessor.supported_languages: |
|
raise ValueError("language %s is not supported for XNLI task." % language) |
|
else: |
|
self.languages = [language] |
|
|
|
def get_train_examples(self, data_dir): |
|
"""See base class.""" |
|
lines = [] |
|
for language in self.languages: |
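
      # Skips the header.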
|
lines.extend( |
|
self._read_tsv( |
|
os.path.join(data_dir, "multinli", |
|
"multinli.train.%s.tsv" % language))[1:]) |
|
|
|
examples = [] |
|
for (i, line) in enumerate(lines): |
|
guid = "train-%d" % i |
|
text_a = self.process_text_fn(line[0]) |
|
text_b = self.process_text_fn(line[1]) |
|
label = self.process_text_fn(line[2]) |
|
if label == self.process_text_fn("contradictory"): |
|
label = self.process_text_fn("contradiction") |
|
examples.append( |
|
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) |
|
return examples |
|
|
|
def get_dev_examples(self, data_dir): |
|
"""See base class.""" |
|
lines = self._read_tsv(os.path.join(data_dir, "xnli.dev.tsv")) |
|
examples = [] |
|
for (i, line) in enumerate(lines): |
|
if i == 0: |
|
continue |
|
guid = "dev-%d" % i |
|
text_a = self.process_text_fn(line[6]) |
|
text_b = self.process_text_fn(line[7]) |
|
label = self.process_text_fn(line[1]) |
|
examples.append( |
|
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) |
|
return examples |
|
|
|
def get_test_examples(self, data_dir): |
|
"""See base class.""" |
|
lines = self._read_tsv(os.path.join(data_dir, "xnli.test.tsv")) |
|
examples_by_lang = {k: [] for k in XnliProcessor.supported_languages} |
|
for (i, line) in enumerate(lines): |
|
if i == 0: |
|
continue |
|
guid = "test-%d" % i |
|
language = self.process_text_fn(line[0]) |
|
text_a = self.process_text_fn(line[6]) |
|
text_b = self.process_text_fn(line[7]) |
|
label = self.process_text_fn(line[1]) |
|
examples_by_lang[language].append( |
|
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) |
|
return examples_by_lang |
|
|
|
def get_labels(self): |
|
"""See base class.""" |
|
return ["contradiction", "entailment", "neutral"] |
|
|
|
@staticmethod |
|
def get_processor_name(): |
|
"""See base class.""" |
|
return "XNLI" |
|
|
|
|
|
class XtremeXnliProcessor(DataProcessor): |
|
"""Processor for the XTREME XNLI data set.""" |
|
supported_languages = [ |
|
"ar", "bg", "de", "el", "en", "es", "fr", "hi", "ru", "sw", "th", "tr", |
|
"ur", "vi", "zh" |
|
] |
|
|
|
def get_train_examples(self, data_dir): |
|
"""See base class.""" |
|
lines = self._read_tsv(os.path.join(data_dir, "train-en.tsv")) |
|
|
|
examples = [] |
|
for (i, line) in enumerate(lines): |
|
guid = "train-%d" % i |
|
text_a = self.process_text_fn(line[0]) |
|
text_b = self.process_text_fn(line[1]) |
|
label = self.process_text_fn(line[2]) |
|
examples.append( |
|
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) |
|
return examples |
|
|
|
def get_dev_examples(self, data_dir): |
|
"""See base class.""" |
|
lines = self._read_tsv(os.path.join(data_dir, "dev-en.tsv")) |
|
examples = [] |
|
for (i, line) in enumerate(lines): |
|
guid = "dev-%d" % i |
|
text_a = self.process_text_fn(line[0]) |
|
text_b = self.process_text_fn(line[1]) |
|
label = self.process_text_fn(line[2]) |
|
examples.append( |
|
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) |
|
return examples |
|
|
|
def get_test_examples(self, data_dir): |
|
"""See base class.""" |
|
examples_by_lang = {k: [] for k in self.supported_languages} |
|
for lang in self.supported_languages: |
|
lines = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv")) |
|
for (i, line) in enumerate(lines): |
|
guid = f"test-{i}" |
|
text_a = self.process_text_fn(line[0]) |
|
text_b = self.process_text_fn(line[1]) |
|
label = "contradiction" |
|
examples_by_lang[lang].append( |
|
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) |
|
return examples_by_lang |
|
|
|
def get_labels(self): |
|
"""See base class.""" |
|
return ["contradiction", "entailment", "neutral"] |
|
|
|
@staticmethod |
|
def get_processor_name(): |
|
"""See base class.""" |
|
return "XTREME-XNLI" |
|
|
|
|
|
class PawsxProcessor(DataProcessor): |
|
"""Processor for the PAWS-X data set.""" |
|
supported_languages = ["de", "en", "es", "fr", "ja", "ko", "zh"] |
|
|
|
def __init__(self, |
|
language="en", |
|
process_text_fn=tokenization.convert_to_unicode): |
|
super(PawsxProcessor, self).__init__(process_text_fn) |
|
if language == "all": |
|
self.languages = PawsxProcessor.supported_languages |
|
elif language not in PawsxProcessor.supported_languages: |
|
raise ValueError("language %s is not supported for PAWS-X task." % |
|
language) |
|
else: |
|
self.languages = [language] |
|
|
|
def get_train_examples(self, data_dir): |
|
"""See base class.""" |
|
lines = [] |
|
for language in self.languages: |
|
if language == "en": |
|
train_tsv = "train.tsv" |
|
else: |
|
train_tsv = "translated_train.tsv" |
|
|
|
lines.extend( |
|
self._read_tsv(os.path.join(data_dir, language, train_tsv))[1:]) |
|
|
|
examples = [] |
|
for (i, line) in enumerate(lines): |
|
guid = "train-%d" % i |
|
text_a = self.process_text_fn(line[1]) |
|
text_b = self.process_text_fn(line[2]) |
|
label = self.process_text_fn(line[3]) |
|
examples.append( |
|
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) |
|
return examples |
|
|
|
def get_dev_examples(self, data_dir): |
|
"""See base class.""" |
|
lines = [] |
|
for lang in PawsxProcessor.supported_languages: |
|
lines.extend(self._read_tsv(os.path.join(data_dir, f"dev-{lang}.tsv"))) |
|
|
|
examples = [] |
|
for (i, line) in enumerate(lines): |
|
guid = "dev-%d" % i |
|
text_a = self.process_text_fn(line[0]) |
|
text_b = self.process_text_fn(line[1]) |
|
label = self.process_text_fn(line[2]) |
|
examples.append( |
|
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) |
|
return examples |
|
|
|
def get_test_examples(self, data_dir): |
|
"""See base class.""" |
|
examples_by_lang = {k: [] for k in self.supported_languages} |
|
for lang in self.supported_languages: |
|
lines = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv")) |
|
for (i, line) in enumerate(lines): |
|
guid = "test-%d" % i |
|
text_a = self.process_text_fn(line[0]) |
|
text_b = self.process_text_fn(line[1]) |
|
label = self.process_text_fn(line[2]) |
|
examples_by_lang[lang].append( |
|
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) |
|
return examples_by_lang |
|
|
|
def get_labels(self): |
|
"""See base class.""" |
|
return ["0", "1"] |
|
|
|
@staticmethod |
|
def get_processor_name(): |
|
"""See base class.""" |
|
return "XTREME-PAWS-X" |
|
|
|
|
|
class XtremePawsxProcessor(DataProcessor): |
|
"""Processor for the XTREME PAWS-X data set.""" |
|
supported_languages = ["de", "en", "es", "fr", "ja", "ko", "zh"] |
|
|
|
def get_train_examples(self, data_dir): |
|
"""See base class.""" |
|
lines = self._read_tsv(os.path.join(data_dir, "train-en.tsv")) |
|
examples = [] |
|
for (i, line) in enumerate(lines): |
|
guid = "train-%d" % i |
|
text_a = self.process_text_fn(line[0]) |
|
text_b = self.process_text_fn(line[1]) |
|
label = self.process_text_fn(line[2]) |
|
examples.append( |
|
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) |
|
return examples |
|
|
|
def get_dev_examples(self, data_dir): |
|
"""See base class.""" |
|
lines = self._read_tsv(os.path.join(data_dir, "dev-en.tsv")) |
|
|
|
examples = [] |
|
for (i, line) in enumerate(lines): |
|
guid = "dev-%d" % i |
|
text_a = self.process_text_fn(line[0]) |
|
text_b = self.process_text_fn(line[1]) |
|
label = self.process_text_fn(line[2]) |
|
examples.append( |
|
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) |
|
return examples |
|
|
|
def get_test_examples(self, data_dir): |
|
"""See base class.""" |
|
examples_by_lang = {k: [] for k in self.supported_languages} |
|
for lang in self.supported_languages: |
|
lines = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv")) |
|
for (i, line) in enumerate(lines): |
|
guid = "test-%d" % i |
|
text_a = self.process_text_fn(line[0]) |
|
text_b = self.process_text_fn(line[1]) |
|
label = "0" |
|
examples_by_lang[lang].append( |
|
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) |
|
return examples_by_lang |
|
|
|
def get_labels(self): |
|
"""See base class.""" |
|
return ["0", "1"] |
|
|
|
@staticmethod |
|
def get_processor_name(): |
|
"""See base class.""" |
|
return "XTREME-PAWS-X" |
|
|
|
|
|
class MnliProcessor(DataProcessor): |
|
"""Processor for the MultiNLI data set (GLUE version).""" |
|
|
|
def get_train_examples(self, data_dir): |
|
"""See base class.""" |
|
return self._create_examples( |
|
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") |
|
|
|
def get_dev_examples(self, data_dir): |
|
"""See base class.""" |
|
return self._create_examples( |
|
self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), |
|
"dev_matched") |
|
|
|
def get_test_examples(self, data_dir): |
|
"""See base class.""" |
|
return self._create_examples( |
|
self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test") |
|
|
|
def get_labels(self): |
|
"""See base class.""" |
|
return ["contradiction", "entailment", "neutral"] |
|
|
|
@staticmethod |
|
def get_processor_name(): |
|
"""See base class.""" |
|
return "MNLI" |
|
|
|
def _create_examples(self, lines, set_type): |
|
"""Creates examples for the training and dev sets.""" |
|
examples = [] |
|
for (i, line) in enumerate(lines): |
|
if i == 0: |
|
continue |
|
guid = "%s-%s" % (set_type, self.process_text_fn(line[0])) |
|
text_a = self.process_text_fn(line[8]) |
|
text_b = self.process_text_fn(line[9]) |
|
if set_type == "test": |
|
label = "contradiction" |
|
else: |
|
label = self.process_text_fn(line[-1]) |
|
examples.append( |
|
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) |
|
return examples |
|
|
|
|
|
class MrpcProcessor(DataProcessor): |
|
"""Processor for the MRPC data set (GLUE version).""" |
|
|
|
def get_train_examples(self, data_dir): |
|
"""See base class.""" |
|
return self._create_examples( |
|
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") |
|
|
|
def get_dev_examples(self, data_dir): |
|
"""See base class.""" |
|
return self._create_examples( |
|
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") |
|
|
|
def get_test_examples(self, data_dir): |
|
"""See base class.""" |
|
return self._create_examples( |
|
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") |
|
|
|
def get_labels(self): |
|
"""See base class.""" |
|
return ["0", "1"] |
|
|
|
@staticmethod |
|
def get_processor_name(): |
|
"""See base class.""" |
|
return "MRPC" |
|
|
|
def _create_examples(self, lines, set_type): |
|
"""Creates examples for the training and dev sets.""" |
|
examples = [] |
|
for (i, line) in enumerate(lines): |
|
if i == 0: |
|
continue |
|
guid = "%s-%s" % (set_type, i) |
|
text_a = self.process_text_fn(line[3]) |
|
text_b = self.process_text_fn(line[4]) |
|
if set_type == "test": |
|
label = "0" |
|
else: |
|
label = self.process_text_fn(line[0]) |
|
examples.append( |
|
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) |
|
return examples |
|
|
|
|
|
class QqpProcessor(DataProcessor): |
|
"""Processor for the QQP data set (GLUE version).""" |
|
|
|
def get_train_examples(self, data_dir): |
|
"""See base class.""" |
|
return self._create_examples( |
|
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") |
|
|
|
def get_dev_examples(self, data_dir): |
|
"""See base class.""" |
|
return self._create_examples( |
|
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") |
|
|
|
def get_test_examples(self, data_dir): |
|
"""See base class.""" |
|
return self._create_examples( |
|
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") |
|
|
|
def get_labels(self): |
|
"""See base class.""" |
|
return ["0", "1"] |
|
|
|
@staticmethod |
|
def get_processor_name(): |
|
"""See base class.""" |
|
return "QQP" |
|
|
|
def _create_examples(self, lines, set_type): |
|
"""Creates examples for the training and dev sets.""" |
|
examples = [] |
|
for (i, line) in enumerate(lines): |
|
if i == 0: |
|
continue |
|
guid = "%s-%s" % (set_type, line[0]) |
|
try: |
|
text_a = line[3] |
|
text_b = line[4] |
|
label = line[5] |
|
except IndexError: |
|
continue |
|
examples.append( |
|
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) |
|
return examples |
|
|
|
|
|
class ColaProcessor(DataProcessor): |
|
"""Processor for the CoLA data set (GLUE version).""" |
|
|
|
def get_train_examples(self, data_dir): |
|
"""See base class.""" |
|
return self._create_examples( |
|
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") |
|
|
|
def get_dev_examples(self, data_dir): |
|
"""See base class.""" |
|
return self._create_examples( |
|
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") |
|
|
|
def get_test_examples(self, data_dir): |
|
"""See base class.""" |
|
return self._create_examples( |
|
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") |
|
|
|
def get_labels(self): |
|
"""See base class.""" |
|
return ["0", "1"] |
|
|
|
@staticmethod |
|
def get_processor_name(): |
|
"""See base class.""" |
|
return "COLA" |
|
|
|
def _create_examples(self, lines, set_type): |
|
"""Creates examples for the training and dev sets.""" |
|
examples = [] |
|
for (i, line) in enumerate(lines): |
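
      # Only the test set has a header.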
|
if set_type == "test" and i == 0: |
|
continue |
|
guid = "%s-%s" % (set_type, i) |
|
if set_type == "test": |
|
text_a = self.process_text_fn(line[1]) |
|
label = "0" |
|
else: |
|
text_a = self.process_text_fn(line[3]) |
|
label = self.process_text_fn(line[1]) |
|
examples.append( |
|
InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) |
|
return examples |
|
|
|
|
|
class RteProcessor(DataProcessor): |
|
"""Processor for the RTE data set (GLUE version).""" |
|
|
|
def get_train_examples(self, data_dir): |
|
"""See base class.""" |
|
return self._create_examples( |
|
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") |
|
|
|
def get_dev_examples(self, data_dir): |
|
"""See base class.""" |
|
return self._create_examples( |
|
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") |
|
|
|
def get_test_examples(self, data_dir): |
|
"""See base class.""" |
|
return self._create_examples( |
|
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") |
|
|
|
def get_labels(self): |
|
"""See base class.""" |
|
return ["entailment", "not_entailment"] |
|
|
|
@staticmethod |
|
def get_processor_name(): |
|
"""See base class.""" |
|
return "RTE" |
|
|
|
def _create_examples(self, lines, set_type): |
|
"""Creates examples for the training and dev sets.""" |
|
examples = [] |
|
for i, line in enumerate(lines): |
|
if i == 0: |
|
continue |
|
guid = "%s-%s" % (set_type, i) |
|
if set_type == "test": |
|
text_a = tokenization.convert_to_unicode(line[1]) |
|
text_b = tokenization.convert_to_unicode(line[2]) |
|
label = "entailment" |
|
else: |
|
text_a = tokenization.convert_to_unicode(line[1]) |
|
text_b = tokenization.convert_to_unicode(line[2]) |
|
label = tokenization.convert_to_unicode(line[3]) |
|
examples.append( |
|
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) |
|
return examples |
|
|
|
|
|
class SstProcessor(DataProcessor): |
|
"""Processor for the SST-2 data set (GLUE version).""" |
|
|
|
def get_train_examples(self, data_dir): |
|
"""See base class.""" |
|
return self._create_examples( |
|
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") |
|
|
|
def get_dev_examples(self, data_dir): |
|
"""See base class.""" |
|
return self._create_examples( |
|
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev") |
|
|
|
def get_test_examples(self, data_dir): |
|
"""See base class.""" |
|
return self._create_examples( |
|
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") |
|
|
|
def get_labels(self): |
|
"""See base class.""" |
|
return ["0", "1"] |
|
|
|
@staticmethod |
|
def get_processor_name(): |
|
"""See base class.""" |
|
return "SST-2" |
|
|
|
def _create_examples(self, lines, set_type): |
|
"""Creates examples for the training and dev sets.""" |
|
examples = [] |
|
for (i, line) in enumerate(lines): |
|
if i == 0: |
|
continue |
|
guid = "%s-%s" % (set_type, i) |
|
if set_type == "test": |
|
text_a = tokenization.convert_to_unicode(line[1]) |
|
label = "0" |
|
else: |
|
text_a = tokenization.convert_to_unicode(line[0]) |
|
label = tokenization.convert_to_unicode(line[1]) |
|
examples.append( |
|
InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) |
|
return examples |
|
|
|
|
|
class QnliProcessor(DataProcessor): |
|
"""Processor for the QNLI data set (GLUE version).""" |
|
|
|
def get_train_examples(self, data_dir): |
|
"""See base class.""" |
|
return self._create_examples( |
|
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") |
|
|
|
def get_dev_examples(self, data_dir): |
|
"""See base class.""" |
|
return self._create_examples( |
|
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev_matched") |
|
|
|
def get_test_examples(self, data_dir): |
|
"""See base class.""" |
|
return self._create_examples( |
|
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test") |
|
|
|
def get_labels(self): |
|
"""See base class.""" |
|
return ["entailment", "not_entailment"] |
|
|
|
@staticmethod |
|
def get_processor_name(): |
|
"""See base class.""" |
|
return "QNLI" |
|
|
|
def _create_examples(self, lines, set_type): |
|
"""Creates examples for the training and dev sets.""" |
|
examples = [] |
|
for (i, line) in enumerate(lines): |
|
if i == 0: |
|
continue |
|
guid = "%s-%s" % (set_type, 1) |
|
if set_type == "test": |
|
text_a = tokenization.convert_to_unicode(line[1]) |
|
text_b = tokenization.convert_to_unicode(line[2]) |
|
label = "entailment" |
|
else: |
|
text_a = tokenization.convert_to_unicode(line[1]) |
|
text_b = tokenization.convert_to_unicode(line[2]) |
|
label = tokenization.convert_to_unicode(line[-1]) |
|
examples.append( |
|
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) |
|
return examples |
|
|
|
|
|
class TfdsProcessor(DataProcessor): |
|
"""Processor for generic text classification and regression TFDS data set. |
|
|
|
The TFDS parameters are expected to be provided in the tfds_params string, in |
|
a comma-separated list of parameter assignments. |
|
Examples: |
|
tfds_params="dataset=scicite,text_key=string" |
|
tfds_params="dataset=imdb_reviews,test_split=,dev_split=test" |
|
tfds_params="dataset=glue/cola,text_key=sentence" |
|
tfds_params="dataset=glue/sst2,text_key=sentence" |
|
tfds_params="dataset=glue/qnli,text_key=question,text_b_key=sentence" |
|
tfds_params="dataset=glue/mrpc,text_key=sentence1,text_b_key=sentence2" |
|
tfds_params="dataset=glue/stsb,text_key=sentence1,text_b_key=sentence2," |
|
"is_regression=true,label_type=float" |
|
  Possible parameters (please refer to the documentation of TensorFlow

  Datasets (TFDS) for the meaning of individual parameters):
|
dataset: Required dataset name (potentially with subset and version number). |
|
data_dir: Optional TFDS source root directory. |
|
module_import: Optional Dataset module to import. |
|
train_split: Name of the train split (defaults to `train`). |
|
dev_split: Name of the dev split (defaults to `validation`). |
|
test_split: Name of the test split (defaults to `test`). |
|
text_key: Key of the text_a feature (defaults to `text`). |
|
text_b_key: Key of the second text feature if available. |
|
label_key: Key of the label feature (defaults to `label`). |
|
test_text_key: Key of the text feature to use in test set. |
|
test_text_b_key: Key of the second text feature to use in test set. |
|
test_label: String to be used as the label for all test examples. |
|
label_type: Type of the label key (defaults to `int`). |
|
    weight_key: Key of the float sample weight (not used if not provided).
|
is_regression: Whether the task is a regression problem (defaults to False). |
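
  Example usage (a sketch; assumes the `glue/cola` TFDS dataset can be

  downloaded or is already cached):

    processor = TfdsProcessor(

        tfds_params="dataset=glue/cola,text_key=sentence")

    train_examples = processor.get_train_examples(None)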
|
""" |
|
|
|
def __init__(self, |
|
tfds_params, |
|
process_text_fn=tokenization.convert_to_unicode): |
|
super(TfdsProcessor, self).__init__(process_text_fn) |
|
self._process_tfds_params_str(tfds_params) |
|
if self.module_import: |
|
importlib.import_module(self.module_import) |
|
|
|
self.dataset, info = tfds.load( |
|
self.dataset_name, data_dir=self.data_dir, with_info=True) |
|
if self.is_regression: |
|
self._labels = None |
|
else: |
|
self._labels = list(range(info.features[self.label_key].num_classes)) |
|
|
|
def _process_tfds_params_str(self, params_str): |
|
"""Extracts TFDS parameters from a comma-separated assignements string.""" |
|
dtype_map = {"int": int, "float": float} |
|
cast_str_to_bool = lambda s: s.lower() not in ["false", "0"] |
|
|
|
tuples = [x.split("=") for x in params_str.split(",")] |
|
d = {k.strip(): v.strip() for k, v in tuples} |
|
self.dataset_name = d["dataset"] |
|
self.data_dir = d.get("data_dir", None) |
|
self.module_import = d.get("module_import", None) |
|
self.train_split = d.get("train_split", "train") |
|
self.dev_split = d.get("dev_split", "validation") |
|
self.test_split = d.get("test_split", "test") |
|
self.text_key = d.get("text_key", "text") |
|
self.text_b_key = d.get("text_b_key", None) |
|
self.label_key = d.get("label_key", "label") |
|
self.test_text_key = d.get("test_text_key", self.text_key) |
|
self.test_text_b_key = d.get("test_text_b_key", self.text_b_key) |
|
self.test_label = d.get("test_label", "test_example") |
|
self.label_type = dtype_map[d.get("label_type", "int")] |
|
self.is_regression = cast_str_to_bool(d.get("is_regression", "False")) |
|
self.weight_key = d.get("weight_key", None) |
|
|
|
def get_train_examples(self, data_dir): |
|
assert data_dir is None |
|
return self._create_examples(self.train_split, "train") |
|
|
|
def get_dev_examples(self, data_dir): |
|
assert data_dir is None |
|
return self._create_examples(self.dev_split, "dev") |
|
|
|
def get_test_examples(self, data_dir): |
|
assert data_dir is None |
|
return self._create_examples(self.test_split, "test") |
|
|
|
def get_labels(self): |
|
return self._labels |
|
|
|
def get_processor_name(self): |
|
return "TFDS_" + self.dataset_name |
|
|
|
def _create_examples(self, split_name, set_type): |
|
"""Creates examples for the training and dev sets.""" |
|
if split_name not in self.dataset: |
|
raise ValueError("Split {} not available.".format(split_name)) |
|
dataset = self.dataset[split_name].as_numpy_iterator() |
|
examples = [] |
|
text_b, weight = None, None |
|
for i, example in enumerate(dataset): |
|
guid = "%s-%s" % (set_type, i) |
|
if set_type == "test": |
|
text_a = self.process_text_fn(example[self.test_text_key]) |
|
if self.test_text_b_key: |
|
text_b = self.process_text_fn(example[self.test_text_b_key]) |
|
label = self.test_label |
|
else: |
|
text_a = self.process_text_fn(example[self.text_key]) |
|
if self.text_b_key: |
|
text_b = self.process_text_fn(example[self.text_b_key]) |
|
label = self.label_type(example[self.label_key]) |
|
if self.weight_key: |
|
weight = float(example[self.weight_key]) |
|
examples.append( |
|
InputExample( |
|
guid=guid, |
|
text_a=text_a, |
|
text_b=text_b, |
|
label=label, |
|
weight=weight)) |
|
return examples |
|
|
|
|
|
def convert_single_example(ex_index, example, label_list, max_seq_length, |
|
tokenizer): |
|
"""Converts a single `InputExample` into a single `InputFeatures`.""" |
|
label_map = {} |
|
if label_list: |
|
for (i, label) in enumerate(label_list): |
|
label_map[label] = i |
|
|
|
tokens_a = tokenizer.tokenize(example.text_a) |
|
tokens_b = None |
|
if example.text_b: |
|
tokens_b = tokenizer.tokenize(example.text_b) |
|
|
|
if tokens_b: |
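
    # Modifies `tokens_a` and `tokens_b` in place so that the total

    # length is less than the specified length.

    # Account for [CLS], [SEP], [SEP] with "- 3"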
|
_truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3) |
|
else: |
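
    # Account for [CLS] and [SEP] with "- 2"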
|
if len(tokens_a) > max_seq_length - 2: |
|
tokens_a = tokens_a[0:(max_seq_length - 2)] |
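
  # The convention in BERT is:

  # (a) For sequence pairs:

  #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]

  #  type_ids: 0     0  0    0    0     0       0 0     1  1  1  1   1 1

  # (b) For single sequences:

  #  tokens:   [CLS] the dog is hairy . [SEP]

  #  type_ids: 0     0   0   0  0     0 0

  #

  # Where "type_ids" are used to indicate whether this is the first

  # sequence or the second sequence. The embedding vectors for `type=0` and

  # `type=1` were learned during pre-training and are added to the wordpiece

  # embedding vector (and position vector). This is not *strictly* necessary

  # since the [SEP] token unambiguously separates the sequences, but it makes

  # it easier for the model to learn the concept of sequences.

  #

  # For classification tasks, the first vector (corresponding to [CLS]) is

  # used as the "sentence vector". Note that this only makes sense because

  # the entire model is fine-tuned.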
|
tokens = [] |
|
segment_ids = [] |
|
tokens.append("[CLS]") |
|
segment_ids.append(0) |
|
for token in tokens_a: |
|
tokens.append(token) |
|
segment_ids.append(0) |
|
tokens.append("[SEP]") |
|
segment_ids.append(0) |
|
|
|
if tokens_b: |
|
for token in tokens_b: |
|
tokens.append(token) |
|
segment_ids.append(1) |
|
tokens.append("[SEP]") |
|
segment_ids.append(1) |
|
|
|
input_ids = tokenizer.convert_tokens_to_ids(tokens) |
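
  # The mask has 1 for real tokens and 0 for padding tokens. Only real

  # tokens are attended to.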
|
input_mask = [1] * len(input_ids) |
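
  # Zero-pad up to the sequence length.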
|
while len(input_ids) < max_seq_length: |
|
input_ids.append(0) |
|
input_mask.append(0) |
|
segment_ids.append(0) |
|
|
|
assert len(input_ids) == max_seq_length |
|
assert len(input_mask) == max_seq_length |
|
assert len(segment_ids) == max_seq_length |
|
|
|
label_id = label_map[example.label] if label_map else example.label |
|
if ex_index < 5: |
|
logging.info("*** Example ***") |
|
logging.info("guid: %s", (example.guid)) |
|
logging.info("tokens: %s", |
|
" ".join([tokenization.printable_text(x) for x in tokens])) |
|
logging.info("input_ids: %s", " ".join([str(x) for x in input_ids])) |
|
logging.info("input_mask: %s", " ".join([str(x) for x in input_mask])) |
|
logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids])) |
|
logging.info("label: %s (id = %s)", example.label, str(label_id)) |
|
logging.info("weight: %s", example.weight) |
|
logging.info("int_iden: %s", str(example.int_iden)) |
|
|
|
feature = InputFeatures( |
|
input_ids=input_ids, |
|
input_mask=input_mask, |
|
segment_ids=segment_ids, |
|
label_id=label_id, |
|
is_real_example=True, |
|
weight=example.weight, |
|
int_iden=example.int_iden) |
|
|
|
return feature |
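
# A minimal sketch of calling `convert_single_example` directly (assumes a

# WordPiece vocabulary file "vocab.txt" compatible with

# `tokenization.FullTokenizer`):

#

#   tokenizer = tokenization.FullTokenizer(

#       vocab_file="vocab.txt", do_lower_case=True)

#   example = InputExample(guid="demo-0", text_a="hello world", label="0")

#   feature = convert_single_example(0, example, ["0", "1"], 128, tokenizer)

#   assert len(feature.input_ids) == 128  # Padded to max_seq_length.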
|
|
|
|
|
def file_based_convert_examples_to_features(examples, |
|
label_list, |
|
max_seq_length, |
|
tokenizer, |
|
output_file, |
|
label_type=None): |
|
"""Convert a set of `InputExample`s to a TFRecord file.""" |
|
|
|
tf.io.gfile.makedirs(os.path.dirname(output_file)) |
|
writer = tf.io.TFRecordWriter(output_file) |
|
|
|
for (ex_index, example) in enumerate(examples): |
|
if ex_index % 10000 == 0: |
|
logging.info("Writing example %d of %d", ex_index, len(examples)) |
|
|
|
feature = convert_single_example(ex_index, example, label_list, |
|
max_seq_length, tokenizer) |
|
|
|
def create_int_feature(values): |
|
f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) |
|
return f |
|
|
|
def create_float_feature(values): |
|
f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values))) |
|
return f |
|
|
|
features = collections.OrderedDict() |
|
features["input_ids"] = create_int_feature(feature.input_ids) |
|
features["input_mask"] = create_int_feature(feature.input_mask) |
|
features["segment_ids"] = create_int_feature(feature.segment_ids) |
|
if label_type is not None and label_type == float: |
|
features["label_ids"] = create_float_feature([feature.label_id]) |
|
elif feature.label_id is not None: |
|
features["label_ids"] = create_int_feature([feature.label_id]) |
|
features["is_real_example"] = create_int_feature( |
|
[int(feature.is_real_example)]) |
|
if feature.weight is not None: |
|
features["weight"] = create_float_feature([feature.weight]) |
|
if feature.int_iden is not None: |
|
features["int_iden"] = create_int_feature([feature.int_iden]) |
|
|
|
tf_example = tf.train.Example(features=tf.train.Features(feature=features)) |
|
writer.write(tf_example.SerializeToString()) |
|
writer.close() |
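
# Sketch of reading the written records back; the feature names and dtypes

# below match what this writer emits for an integer-labeled task, and

# `max_seq_length` must match the value used when writing:

#

#   name_to_features = {

#       "input_ids": tf.io.FixedLenFeature([max_seq_length], tf.int64),

#       "input_mask": tf.io.FixedLenFeature([max_seq_length], tf.int64),

#       "segment_ids": tf.io.FixedLenFeature([max_seq_length], tf.int64),

#       "label_ids": tf.io.FixedLenFeature([], tf.int64),

#       "is_real_example": tf.io.FixedLenFeature([], tf.int64),

#   }

#   dataset = tf.data.TFRecordDataset(output_file).map(

#       lambda record: tf.io.parse_single_example(record, name_to_features))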
|
|
|
|
|
def _truncate_seq_pair(tokens_a, tokens_b, max_length): |
|
"""Truncates a sequence pair in place to the maximum length.""" |
|
while True: |
|
total_length = len(tokens_a) + len(tokens_b) |
|
if total_length <= max_length: |
|
break |
|
if len(tokens_a) > len(tokens_b): |
|
tokens_a.pop() |
|
else: |
|
tokens_b.pop() |
|
|
|
|
|
def generate_tf_record_from_data_file(processor, |
|
data_dir, |
|
tokenizer, |
|
train_data_output_path=None, |
|
eval_data_output_path=None, |
|
test_data_output_path=None, |
|
max_seq_length=128): |
|
"""Generates and saves training data into a tf record file. |
|
|
|
  Args:
|
processor: Input processor object to be used for generating data. Subclass |
|
of `DataProcessor`. |
|
    data_dir: Directory that contains train/eval/test data to process. Data

      files should be named "train.tsv", "dev.tsv", or "test.tsv".
|
tokenizer: The tokenizer to be applied on the data. |
|
train_data_output_path: Output to which processed tf record for training |
|
will be saved. |
|
eval_data_output_path: Output to which processed tf record for evaluation |
|
will be saved. |
|
    test_data_output_path: Output to which processed tf record for testing

      will be saved. Must be a path template containing {} if the processor

      has language-specific test data.
|
    max_seq_length: Maximum sequence length of the generated training/eval

      data.
|
|
|
Returns: |
|
A dictionary containing input meta data. |
|
""" |
|
  assert train_data_output_path
|
|
|
label_list = processor.get_labels() |
|
label_type = getattr(processor, "label_type", None) |
|
is_regression = getattr(processor, "is_regression", False) |
|
has_sample_weights = getattr(processor, "weight_key", False) |
|
|
|
|
train_input_data_examples = processor.get_train_examples(data_dir) |
|
file_based_convert_examples_to_features(train_input_data_examples, label_list, |
|
max_seq_length, tokenizer, |
|
train_data_output_path, label_type) |
|
num_training_data = len(train_input_data_examples) |
|
|
|
if eval_data_output_path: |
|
eval_input_data_examples = processor.get_dev_examples(data_dir) |
|
file_based_convert_examples_to_features(eval_input_data_examples, |
|
label_list, max_seq_length, |
|
tokenizer, eval_data_output_path, |
|
label_type) |
|
|
|
if test_data_output_path: |
|
test_input_data_examples = processor.get_test_examples(data_dir) |
|
if isinstance(test_input_data_examples, dict): |
|
for language, examples in test_input_data_examples.items(): |
|
file_based_convert_examples_to_features( |
|
examples, label_list, max_seq_length, tokenizer, |
|
test_data_output_path.format(language), label_type) |
|
else: |
|
file_based_convert_examples_to_features(test_input_data_examples, |
|
label_list, max_seq_length, |
|
tokenizer, test_data_output_path, |
|
label_type) |
|
|
|
meta_data = { |
|
"processor_type": processor.get_processor_name(), |
|
"train_data_size": num_training_data, |
|
"max_seq_length": max_seq_length, |
|
} |
|
if is_regression: |
|
meta_data["task_type"] = "bert_regression" |
|
meta_data["label_type"] = {int: "int", float: "float"}[label_type] |
|
else: |
|
meta_data["task_type"] = "bert_classification" |
|
meta_data["num_labels"] = len(processor.get_labels()) |
|
if has_sample_weights: |
|
meta_data["has_sample_weights"] = True |
|
|
|
if eval_data_output_path: |
|
meta_data["eval_data_size"] = len(eval_input_data_examples) |
|
|
|
if test_data_output_path: |
|
|
if isinstance(test_input_data_examples, dict): |
|
for language, examples in test_input_data_examples.items(): |
|
meta_data["test_{}_data_size".format(language)] = len(examples) |
|
else: |
|
meta_data["test_data_size"] = len(test_input_data_examples) |
|
|
|
return meta_data |
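
# End-to-end usage sketch (assumes GLUE MRPC data under "glue_data/MRPC" and

# a BERT WordPiece vocabulary at "vocab.txt"; both paths are placeholders):

#

#   processor = MrpcProcessor()

#   tokenizer = tokenization.FullTokenizer(

#       vocab_file="vocab.txt", do_lower_case=True)

#   meta_data = generate_tf_record_from_data_file(

#       processor,

#       data_dir="glue_data/MRPC",

#       tokenizer=tokenizer,

#       train_data_output_path="mrpc_train.tf_record",

#       eval_data_output_path="mrpc_eval.tf_record",

#       max_seq_length=128)

#   # meta_data contains e.g. "train_data_size" and "num_labels".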
|
|