Pradeep Kumar committed
Commit
d6fca6a
1 Parent(s): 17cab64

Upload classifier_data_lib.py

Files changed (1)
  1. classifier_data_lib.py +1612 -0
classifier_data_lib.py ADDED
@@ -0,0 +1,1612 @@
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""BERT library to process data for classification tasks."""

import collections
import csv
import importlib
import json
import os

from absl import logging
import tensorflow as tf, tf_keras
import tensorflow_datasets as tfds

from official.nlp.tools import tokenization


class InputExample(object):
  """A single training/test example for simple seq regression/classification."""

  def __init__(self,
               guid,
               text_a,
               text_b=None,
               label=None,
               weight=None,
               example_id=None):
    """Constructs an InputExample.

    Args:
      guid: Unique id for the example.
      text_a: string. The untokenized text of the first sequence. For single
        sequence tasks, only this sequence must be specified.
      text_b: (Optional) string. The untokenized text of the second sequence.
        Must be specified only for sequence pair tasks.
      label: (Optional) string for classification, float for regression. The
        label of the example. This should be specified for train and dev
        examples, but not for test examples.
      weight: (Optional) float. The weight of the example to be used during
        training.
      example_id: (Optional) int. The int identification number of example in
        the corpus.
    """
    self.guid = guid
    self.text_a = text_a
    self.text_b = text_b
    self.label = label
    self.weight = weight
    self.example_id = example_id
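

# Illustrative sketch: a sentence-pair `InputExample` as consumed by the
# processors below. All field values here are made-up placeholders.
#
#   example = InputExample(
#       guid="train-0",
#       text_a="The cat sat on the mat.",
#       text_b="A cat is sitting on a mat.",
#       label="entailment")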


class InputFeatures(object):
  """A single set of features of data."""

  def __init__(self,
               input_ids,
               input_mask,
               segment_ids,
               label_id,
               is_real_example=True,
               weight=None,
               example_id=None):
    self.input_ids = input_ids
    self.input_mask = input_mask
    self.segment_ids = segment_ids
    self.label_id = label_id
    self.is_real_example = is_real_example
    self.weight = weight
    self.example_id = example_id


class DataProcessor(object):
  """Base class for converters for seq regression/classification datasets."""

  def __init__(self, process_text_fn=tokenization.convert_to_unicode):
    self.process_text_fn = process_text_fn
    self.is_regression = False
    self.label_type = None

  def get_train_examples(self, data_dir):
    """Gets a collection of `InputExample`s for the train set."""
    raise NotImplementedError()

  def get_dev_examples(self, data_dir):
    """Gets a collection of `InputExample`s for the dev set."""
    raise NotImplementedError()

  def get_test_examples(self, data_dir):
    """Gets a collection of `InputExample`s for prediction."""
    raise NotImplementedError()

  def get_labels(self):
    """Gets the list of labels for this data set."""
    raise NotImplementedError()

  @staticmethod
  def get_processor_name():
    """Gets the string identifier of the processor."""
    raise NotImplementedError()

  @classmethod
  def _read_tsv(cls, input_file, quotechar=None):
    """Reads a tab-separated value file."""
    with tf.io.gfile.GFile(input_file, "r") as f:
      reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
      lines = []
      for line in reader:
        lines.append(line)
      return lines

  @classmethod
  def _read_jsonl(cls, input_file):
    """Reads a JSON Lines file."""
    with tf.io.gfile.GFile(input_file, "r") as f:
      lines = []
      for json_str in f:
        lines.append(json.loads(json_str))
    return lines

  def featurize_example(self, *args, **kwargs):
    """Converts a single `InputExample` into a single `InputFeatures`."""
    return convert_single_example(*args, **kwargs)
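

# Minimal sketch of a custom processor built on `DataProcessor`. The file
# name, column layout and label set below are assumptions for illustration,
# not an existing dataset.
#
#   class MyTsvProcessor(DataProcessor):
#
#     def get_train_examples(self, data_dir):
#       lines = self._read_tsv(os.path.join(data_dir, "train.tsv"))
#       return [
#           InputExample(guid="train-%d" % i,
#                        text_a=self.process_text_fn(line[0]),
#                        label=self.process_text_fn(line[1]))
#           for i, line in enumerate(lines)
#       ]
#
#     def get_labels(self):
#       return ["0", "1"]
#
#     @staticmethod
#     def get_processor_name():
#       return "MY_TSV"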


class DefaultGLUEDataProcessor(DataProcessor):
  """Base processor for GLUE datasets loaded from TFDS."""

  def get_train_examples(self, data_dir):
    """See base class."""
    return self._create_examples_tfds("train")

  def get_dev_examples(self, data_dir):
    """See base class."""
    return self._create_examples_tfds("validation")

  def get_test_examples(self, data_dir):
    """See base class."""
    return self._create_examples_tfds("test")

  def _create_examples_tfds(self, set_type):
    """Creates examples for the training/dev/test sets."""
    raise NotImplementedError()


class AxProcessor(DataProcessor):
  """Processor for the AX dataset (GLUE diagnostics dataset)."""

  def get_train_examples(self, data_dir):
    """See base class."""
    train_mnli_dataset = tfds.load(
        "glue/mnli", split="train", try_gcs=True).as_numpy_iterator()
    return self._create_examples_tfds(train_mnli_dataset, "train")

  def get_dev_examples(self, data_dir):
    """See base class."""
    val_mnli_dataset = tfds.load(
        "glue/mnli", split="validation_matched",
        try_gcs=True).as_numpy_iterator()
    return self._create_examples_tfds(val_mnli_dataset, "validation")

  def get_test_examples(self, data_dir):
    """See base class."""
    test_ax_dataset = tfds.load(
        "glue/ax", split="test", try_gcs=True).as_numpy_iterator()
    return self._create_examples_tfds(test_ax_dataset, "test")

  def get_labels(self):
    """See base class."""
    return ["contradiction", "entailment", "neutral"]

  @staticmethod
  def get_processor_name():
    """See base class."""
    return "AX"

  def _create_examples_tfds(self, dataset, set_type):
    """Creates examples for the training/dev/test sets."""
    dataset = list(dataset)
    dataset.sort(key=lambda x: x["idx"])
    examples = []
    for i, example in enumerate(dataset):
      guid = "%s-%s" % (set_type, i)
      label = "contradiction"
      text_a = self.process_text_fn(example["hypothesis"])
      text_b = self.process_text_fn(example["premise"])
      if set_type != "test":
        label = self.get_labels()[example["label"]]
      examples.append(
          InputExample(
              guid=guid, text_a=text_a, text_b=text_b, label=label,
              weight=None))
    return examples


class ColaProcessor(DefaultGLUEDataProcessor):
  """Processor for the CoLA data set (GLUE version)."""

  def get_labels(self):
    """See base class."""
    return ["0", "1"]

  @staticmethod
  def get_processor_name():
    """See base class."""
    return "COLA"

  def _create_examples_tfds(self, set_type):
    """Creates examples for the training/dev/test sets."""
    dataset = tfds.load(
        "glue/cola", split=set_type, try_gcs=True).as_numpy_iterator()
    dataset = list(dataset)
    dataset.sort(key=lambda x: x["idx"])
    examples = []
    for i, example in enumerate(dataset):
      guid = "%s-%s" % (set_type, i)
      label = "0"
      text_a = self.process_text_fn(example["sentence"])
      if set_type != "test":
        label = str(example["label"])
      examples.append(
          InputExample(
              guid=guid, text_a=text_a, text_b=None, label=label, weight=None))
    return examples


class ImdbProcessor(DataProcessor):
  """Processor for the IMDb dataset."""

  def get_labels(self):
    return ["neg", "pos"]

  def get_train_examples(self, data_dir):
    return self._create_examples(os.path.join(data_dir, "train"))

  def get_dev_examples(self, data_dir):
    return self._create_examples(os.path.join(data_dir, "test"))

  @staticmethod
  def get_processor_name():
    """See base class."""
    return "IMDB"

  def _create_examples(self, data_dir):
    """Creates examples."""
    examples = []
    for label in ["neg", "pos"]:
      cur_dir = os.path.join(data_dir, label)
      for filename in tf.io.gfile.listdir(cur_dir):
        if not filename.endswith("txt"):
          continue

        if len(examples) % 1000 == 0:
          logging.info("Loading example %d", len(examples))

        path = os.path.join(cur_dir, filename)
        with tf.io.gfile.GFile(path, "r") as f:
          text = f.read().strip().replace("<br />", " ")
        examples.append(
            InputExample(
                guid="unused_id", text_a=text, text_b=None, label=label))
    return examples


class MnliProcessor(DataProcessor):
  """Processor for the MultiNLI data set (GLUE version)."""

  def __init__(self,
               mnli_type="matched",
               process_text_fn=tokenization.convert_to_unicode):
    super(MnliProcessor, self).__init__(process_text_fn)
    self.dataset = tfds.load("glue/mnli", try_gcs=True)
    if mnli_type not in ("matched", "mismatched"):
      raise ValueError("Invalid `mnli_type`: %s" % mnli_type)
    self.mnli_type = mnli_type

  def get_train_examples(self, data_dir):
    """See base class."""
    return self._create_examples_tfds("train")

  def get_dev_examples(self, data_dir):
    """See base class."""
    if self.mnli_type == "matched":
      return self._create_examples_tfds("validation_matched")
    else:
      return self._create_examples_tfds("validation_mismatched")

  def get_test_examples(self, data_dir):
    """See base class."""
    if self.mnli_type == "matched":
      return self._create_examples_tfds("test_matched")
    else:
      return self._create_examples_tfds("test_mismatched")

  def get_labels(self):
    """See base class."""
    return ["contradiction", "entailment", "neutral"]

  @staticmethod
  def get_processor_name():
    """See base class."""
    return "MNLI"

  def _create_examples_tfds(self, set_type):
    """Creates examples for the training/dev/test sets."""
    dataset = tfds.load(
        "glue/mnli", split=set_type, try_gcs=True).as_numpy_iterator()
    dataset = list(dataset)
    dataset.sort(key=lambda x: x["idx"])
    examples = []
    for i, example in enumerate(dataset):
      guid = "%s-%s" % (set_type, i)
      label = "contradiction"
      text_a = self.process_text_fn(example["hypothesis"])
      text_b = self.process_text_fn(example["premise"])
      if set_type != "test":
        label = self.get_labels()[example["label"]]
      examples.append(
          InputExample(
              guid=guid, text_a=text_a, text_b=text_b, label=label,
              weight=None))
    return examples


class MrpcProcessor(DefaultGLUEDataProcessor):
  """Processor for the MRPC data set (GLUE version)."""

  def get_labels(self):
    """See base class."""
    return ["0", "1"]

  @staticmethod
  def get_processor_name():
    """See base class."""
    return "MRPC"

  def _create_examples_tfds(self, set_type):
    """Creates examples for the training/dev/test sets."""
    dataset = tfds.load(
        "glue/mrpc", split=set_type, try_gcs=True).as_numpy_iterator()
    dataset = list(dataset)
    dataset.sort(key=lambda x: x["idx"])
    examples = []
    for i, example in enumerate(dataset):
      guid = "%s-%s" % (set_type, i)
      label = "0"
      text_a = self.process_text_fn(example["sentence1"])
      text_b = self.process_text_fn(example["sentence2"])
      if set_type != "test":
        label = str(example["label"])
      examples.append(
          InputExample(
              guid=guid, text_a=text_a, text_b=text_b, label=label,
              weight=None))
    return examples


class PawsxProcessor(DataProcessor):
  """Processor for the PAWS-X data set."""
  supported_languages = ["de", "en", "es", "fr", "ja", "ko", "zh"]

  def __init__(self,
               language="en",
               process_text_fn=tokenization.convert_to_unicode):
    super(PawsxProcessor, self).__init__(process_text_fn)
    if language == "all":
      self.languages = PawsxProcessor.supported_languages
    elif language not in PawsxProcessor.supported_languages:
      raise ValueError("language %s is not supported for PAWS-X task." %
                       language)
    else:
      self.languages = [language]

  def get_train_examples(self, data_dir):
    """See base class."""
    lines = []
    for language in self.languages:
      if language == "en":
        train_tsv = "train.tsv"
      else:
        train_tsv = "translated_train.tsv"
      # Skips the header.
      lines.extend(
          self._read_tsv(os.path.join(data_dir, language, train_tsv))[1:])

    examples = []
    for i, line in enumerate(lines):
      guid = "train-%d" % i
      text_a = self.process_text_fn(line[1])
      text_b = self.process_text_fn(line[2])
      label = self.process_text_fn(line[3])
      examples.append(
          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
    return examples

  def get_dev_examples(self, data_dir):
    """See base class."""
    lines = []
    for lang in PawsxProcessor.supported_languages:
      lines.extend(
          self._read_tsv(os.path.join(data_dir, lang, "dev_2k.tsv"))[1:])

    examples = []
    for i, line in enumerate(lines):
      guid = "dev-%d" % i
      text_a = self.process_text_fn(line[1])
      text_b = self.process_text_fn(line[2])
      label = self.process_text_fn(line[3])
      examples.append(
          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
    return examples

  def get_test_examples(self, data_dir):
    """See base class."""
    examples_by_lang = {k: [] for k in self.supported_languages}
    for lang in self.supported_languages:
      lines = self._read_tsv(os.path.join(data_dir, lang, "test_2k.tsv"))[1:]
      for i, line in enumerate(lines):
        guid = "test-%d" % i
        text_a = self.process_text_fn(line[1])
        text_b = self.process_text_fn(line[2])
        label = self.process_text_fn(line[3])
        examples_by_lang[lang].append(
            InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
    return examples_by_lang

  def get_labels(self):
    """See base class."""
    return ["0", "1"]

  @staticmethod
  def get_processor_name():
    """See base class."""
    return "XTREME-PAWS-X"


class QnliProcessor(DefaultGLUEDataProcessor):
  """Processor for the QNLI data set (GLUE version)."""

  def get_labels(self):
    """See base class."""
    return ["entailment", "not_entailment"]

  @staticmethod
  def get_processor_name():
    """See base class."""
    return "QNLI"

  def _create_examples_tfds(self, set_type):
    """Creates examples for the training/dev/test sets."""
    dataset = tfds.load(
        "glue/qnli", split=set_type, try_gcs=True).as_numpy_iterator()
    dataset = list(dataset)
    dataset.sort(key=lambda x: x["idx"])
    examples = []
    for i, example in enumerate(dataset):
      guid = "%s-%s" % (set_type, i)
      label = "entailment"
      text_a = self.process_text_fn(example["question"])
      text_b = self.process_text_fn(example["sentence"])
      if set_type != "test":
        label = self.get_labels()[example["label"]]
      examples.append(
          InputExample(
              guid=guid, text_a=text_a, text_b=text_b, label=label,
              weight=None))
    return examples


class QqpProcessor(DefaultGLUEDataProcessor):
  """Processor for the QQP data set (GLUE version)."""

  def get_labels(self):
    """See base class."""
    return ["0", "1"]

  @staticmethod
  def get_processor_name():
    """See base class."""
    return "QQP"

  def _create_examples_tfds(self, set_type):
    """Creates examples for the training/dev/test sets."""
    dataset = tfds.load(
        "glue/qqp", split=set_type, try_gcs=True).as_numpy_iterator()
    dataset = list(dataset)
    dataset.sort(key=lambda x: x["idx"])
    examples = []
    for i, example in enumerate(dataset):
      guid = "%s-%s" % (set_type, i)
      label = "0"
      text_a = self.process_text_fn(example["question1"])
      text_b = self.process_text_fn(example["question2"])
      if set_type != "test":
        label = str(example["label"])
      examples.append(
          InputExample(
              guid=guid, text_a=text_a, text_b=text_b, label=label,
              weight=None))
    return examples


class RteProcessor(DefaultGLUEDataProcessor):
  """Processor for the RTE data set (GLUE version)."""

  def get_labels(self):
    """See base class."""
    # All datasets are converted to 2-class split, where for 3-class datasets
    # we collapse neutral and contradiction into not_entailment.
    return ["entailment", "not_entailment"]

  @staticmethod
  def get_processor_name():
    """See base class."""
    return "RTE"

  def _create_examples_tfds(self, set_type):
    """Creates examples for the training/dev/test sets."""
    dataset = tfds.load(
        "glue/rte", split=set_type, try_gcs=True).as_numpy_iterator()
    dataset = list(dataset)
    dataset.sort(key=lambda x: x["idx"])
    examples = []
    for i, example in enumerate(dataset):
      guid = "%s-%s" % (set_type, i)
      label = "entailment"
      text_a = self.process_text_fn(example["sentence1"])
      text_b = self.process_text_fn(example["sentence2"])
      if set_type != "test":
        label = self.get_labels()[example["label"]]
      examples.append(
          InputExample(
              guid=guid, text_a=text_a, text_b=text_b, label=label,
              weight=None))
    return examples


class SstProcessor(DefaultGLUEDataProcessor):
  """Processor for the SST-2 data set (GLUE version)."""

  def get_labels(self):
    """See base class."""
    return ["0", "1"]

  @staticmethod
  def get_processor_name():
    """See base class."""
    return "SST-2"

  def _create_examples_tfds(self, set_type):
    """Creates examples for the training/dev/test sets."""
    dataset = tfds.load(
        "glue/sst2", split=set_type, try_gcs=True).as_numpy_iterator()
    dataset = list(dataset)
    dataset.sort(key=lambda x: x["idx"])
    examples = []
    for i, example in enumerate(dataset):
      guid = "%s-%s" % (set_type, i)
      label = "0"
      text_a = self.process_text_fn(example["sentence"])
      if set_type != "test":
        label = str(example["label"])
      examples.append(
          InputExample(
              guid=guid, text_a=text_a, text_b=None, label=label, weight=None))
    return examples


class StsBProcessor(DefaultGLUEDataProcessor):
  """Processor for the STS-B data set (GLUE version)."""

  def __init__(self, process_text_fn=tokenization.convert_to_unicode):
    super(StsBProcessor, self).__init__(process_text_fn=process_text_fn)
    self.is_regression = True
    self.label_type = float
    self._labels = None

  def _create_examples_tfds(self, set_type):
    """Creates examples for the training/dev/test sets."""
    dataset = tfds.load(
        "glue/stsb", split=set_type, try_gcs=True).as_numpy_iterator()
    dataset = list(dataset)
    dataset.sort(key=lambda x: x["idx"])
    examples = []
    for i, example in enumerate(dataset):
      guid = "%s-%s" % (set_type, i)
      label = 0.0
      text_a = self.process_text_fn(example["sentence1"])
      text_b = self.process_text_fn(example["sentence2"])
      if set_type != "test":
        label = self.label_type(example["label"])
      examples.append(
          InputExample(
              guid=guid, text_a=text_a, text_b=text_b, label=label,
              weight=None))
    return examples

  def get_labels(self):
    """See base class."""
    return self._labels

  @staticmethod
  def get_processor_name():
    """See base class."""
    return "STS-B"


class TfdsProcessor(DataProcessor):
  """Processor for generic text classification and regression TFDS datasets.

  The TFDS parameters are expected to be provided in the tfds_params string,
  in a comma-separated list of parameter assignments.
  Examples:
    tfds_params="dataset=scicite,text_key=string"
    tfds_params="dataset=imdb_reviews,test_split=,dev_split=test"
    tfds_params="dataset=glue/cola,text_key=sentence"
    tfds_params="dataset=glue/sst2,text_key=sentence"
    tfds_params="dataset=glue/qnli,text_key=question,text_b_key=sentence"
    tfds_params="dataset=glue/mrpc,text_key=sentence1,text_b_key=sentence2"
    tfds_params="dataset=glue/stsb,text_key=sentence1,text_b_key=sentence2,"
                "is_regression=true,label_type=float"
    tfds_params="dataset=snli,text_key=premise,text_b_key=hypothesis,"
                "skip_label=-1"
  Possible parameters (please refer to the documentation of TensorFlow
  Datasets (TFDS) for the meaning of individual parameters):
    dataset: Required dataset name (potentially with subset and version
      number).
    data_dir: Optional TFDS source root directory.
    module_import: Optional Dataset module to import.
    train_split: Name of the train split (defaults to `train`).
    dev_split: Name of the dev split (defaults to `validation`).
    test_split: Name of the test split (defaults to `test`).
    text_key: Key of the text_a feature (defaults to `text`).
    text_b_key: Key of the second text feature if available.
    label_key: Key of the label feature (defaults to `label`).
    test_text_key: Key of the text feature to use in test set.
    test_text_b_key: Key of the second text feature to use in test set.
    test_label: String to be used as the label for all test examples.
    label_type: Type of the label key (defaults to `int`).
    weight_key: Key of the float sample weight (is not used if not provided).
    is_regression: Whether the task is a regression problem (defaults to
      False).
    skip_label: Skip examples with given label (defaults to None).
  """

  def __init__(self,
               tfds_params,
               process_text_fn=tokenization.convert_to_unicode):
    super(TfdsProcessor, self).__init__(process_text_fn)
    self._process_tfds_params_str(tfds_params)
    if self.module_import:
      importlib.import_module(self.module_import)

    self.dataset, info = tfds.load(
        self.dataset_name, data_dir=self.data_dir, with_info=True)
    if self.is_regression:
      self._labels = None
    else:
      self._labels = list(range(info.features[self.label_key].num_classes))

  def _process_tfds_params_str(self, params_str):
    """Extracts TFDS parameters from a comma-separated assignments string."""
    dtype_map = {"int": int, "float": float}
    cast_str_to_bool = lambda s: s.lower() not in ["false", "0"]

    tuples = [x.split("=") for x in params_str.split(",")]
    d = {k.strip(): v.strip() for k, v in tuples}
    self.dataset_name = d["dataset"]  # Required.
    self.data_dir = d.get("data_dir", None)
    self.module_import = d.get("module_import", None)
    self.train_split = d.get("train_split", "train")
    self.dev_split = d.get("dev_split", "validation")
    self.test_split = d.get("test_split", "test")
    self.text_key = d.get("text_key", "text")
    self.text_b_key = d.get("text_b_key", None)
    self.label_key = d.get("label_key", "label")
    self.test_text_key = d.get("test_text_key", self.text_key)
    self.test_text_b_key = d.get("test_text_b_key", self.text_b_key)
    self.test_label = d.get("test_label", "test_example")
    self.label_type = dtype_map[d.get("label_type", "int")]
    self.is_regression = cast_str_to_bool(d.get("is_regression", "False"))
    self.weight_key = d.get("weight_key", None)
    self.skip_label = d.get("skip_label", None)
    if self.skip_label is not None:
      self.skip_label = self.label_type(self.skip_label)

  def get_train_examples(self, data_dir):
    assert data_dir is None
    return self._create_examples(self.train_split, "train")

  def get_dev_examples(self, data_dir):
    assert data_dir is None
    return self._create_examples(self.dev_split, "dev")

  def get_test_examples(self, data_dir):
    assert data_dir is None
    return self._create_examples(self.test_split, "test")

  def get_labels(self):
    return self._labels

  def get_processor_name(self):
    return "TFDS_" + self.dataset_name

  def _create_examples(self, split_name, set_type):
    """Creates examples for the training/dev/test sets."""
    if split_name not in self.dataset:
      raise ValueError("Split {} not available.".format(split_name))
    dataset = self.dataset[split_name].as_numpy_iterator()
    examples = []
    text_b, weight = None, None
    for i, example in enumerate(dataset):
      guid = "%s-%s" % (set_type, i)
      if set_type == "test":
        text_a = self.process_text_fn(example[self.test_text_key])
        if self.test_text_b_key:
          text_b = self.process_text_fn(example[self.test_text_b_key])
        label = self.test_label
      else:
        text_a = self.process_text_fn(example[self.text_key])
        if self.text_b_key:
          text_b = self.process_text_fn(example[self.text_b_key])
        label = self.label_type(example[self.label_key])
        if self.skip_label is not None and label == self.skip_label:
          continue
      if self.weight_key:
        weight = float(example[self.weight_key])
      examples.append(
          InputExample(
              guid=guid,
              text_a=text_a,
              text_b=text_b,
              label=label,
              weight=weight))
    return examples
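

# Usage sketch for `TfdsProcessor`, following the `tfds_params` format
# documented in the class docstring above. The SST-2 choice is just one of
# the documented examples.
#
#   processor = TfdsProcessor(
#       tfds_params="dataset=glue/sst2,text_key=sentence")
#   train_examples = processor.get_train_examples(data_dir=None)
#   labels = processor.get_labels()  # e.g. [0, 1] for a 2-class dataset.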


class WnliProcessor(DefaultGLUEDataProcessor):
  """Processor for the WNLI data set (GLUE version)."""

  def get_labels(self):
    """See base class."""
    return ["0", "1"]

  @staticmethod
  def get_processor_name():
    """See base class."""
    return "WNLI"

  def _create_examples_tfds(self, set_type):
    """Creates examples for the training/dev/test sets."""
    dataset = tfds.load(
        "glue/wnli", split=set_type, try_gcs=True).as_numpy_iterator()
    dataset = list(dataset)
    dataset.sort(key=lambda x: x["idx"])
    examples = []
    for i, example in enumerate(dataset):
      guid = "%s-%s" % (set_type, i)
      label = "0"
      text_a = self.process_text_fn(example["sentence1"])
      text_b = self.process_text_fn(example["sentence2"])
      if set_type != "test":
        label = str(example["label"])
      examples.append(
          InputExample(
              guid=guid, text_a=text_a, text_b=text_b, label=label,
              weight=None))
    return examples


class XnliProcessor(DataProcessor):
  """Processor for the XNLI data set."""
  supported_languages = [
      "ar", "bg", "de", "el", "en", "es", "fr", "hi", "ru", "sw", "th", "tr",
      "ur", "vi", "zh"
  ]

  def __init__(self,
               language="en",
               process_text_fn=tokenization.convert_to_unicode):
    super(XnliProcessor, self).__init__(process_text_fn)
    if language == "all":
      self.languages = XnliProcessor.supported_languages
    elif language not in XnliProcessor.supported_languages:
      raise ValueError("language %s is not supported for XNLI task." % language)
    else:
      self.languages = [language]

  def get_train_examples(self, data_dir):
    """See base class."""
    lines = []
    for language in self.languages:
      # Skips the header.
      lines.extend(
          self._read_tsv(
              os.path.join(data_dir, "multinli",
                           "multinli.train.%s.tsv" % language))[1:])

    examples = []
    for i, line in enumerate(lines):
      guid = "train-%d" % i
      text_a = self.process_text_fn(line[0])
      text_b = self.process_text_fn(line[1])
      label = self.process_text_fn(line[2])
      if label == self.process_text_fn("contradictory"):
        label = self.process_text_fn("contradiction")
      examples.append(
          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
    return examples

  def get_dev_examples(self, data_dir):
    """See base class."""
    lines = self._read_tsv(os.path.join(data_dir, "xnli.dev.tsv"))
    examples = []
    for i, line in enumerate(lines):
      if i == 0:
        continue
      guid = "dev-%d" % i
      text_a = self.process_text_fn(line[6])
      text_b = self.process_text_fn(line[7])
      label = self.process_text_fn(line[1])
      examples.append(
          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
    return examples

  def get_test_examples(self, data_dir):
    """See base class."""
    lines = self._read_tsv(os.path.join(data_dir, "xnli.test.tsv"))
    examples_by_lang = {k: [] for k in XnliProcessor.supported_languages}
    for i, line in enumerate(lines):
      if i == 0:
        continue
      guid = "test-%d" % i
      language = self.process_text_fn(line[0])
      text_a = self.process_text_fn(line[6])
      text_b = self.process_text_fn(line[7])
      label = self.process_text_fn(line[1])
      examples_by_lang[language].append(
          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
    return examples_by_lang

  def get_labels(self):
    """See base class."""
    return ["contradiction", "entailment", "neutral"]

  @staticmethod
  def get_processor_name():
    """See base class."""
    return "XNLI"


class XtremePawsxProcessor(DataProcessor):
  """Processor for the XTREME PAWS-X data set."""
  supported_languages = ["de", "en", "es", "fr", "ja", "ko", "zh"]

  def __init__(self,
               process_text_fn=tokenization.convert_to_unicode,
               translated_data_dir=None,
               only_use_en_dev=True):
    """See base class.

    Args:
      process_text_fn: See base class.
      translated_data_dir: If specified, will also include translated data in
        the training and testing data.
      only_use_en_dev: If True, only use English dev data. Otherwise, use dev
        data from all languages.
    """
    super(XtremePawsxProcessor, self).__init__(process_text_fn)
    self.translated_data_dir = translated_data_dir
    self.only_use_en_dev = only_use_en_dev

  def get_train_examples(self, data_dir):
    """See base class."""
    examples = []
    if self.translated_data_dir is None:
      lines = self._read_tsv(os.path.join(data_dir, "train-en.tsv"))
      for i, line in enumerate(lines):
        guid = "train-%d" % i
        text_a = self.process_text_fn(line[0])
        text_b = self.process_text_fn(line[1])
        label = self.process_text_fn(line[2])
        examples.append(
            InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
    else:
      for lang in self.supported_languages:
        lines = self._read_tsv(
            os.path.join(self.translated_data_dir, "translate-train",
                         f"en-{lang}-translated.tsv"))
        for i, line in enumerate(lines):
          guid = f"train-{lang}-{i}"
          text_a = self.process_text_fn(line[2])
          text_b = self.process_text_fn(line[3])
          label = self.process_text_fn(line[4])
          examples.append(
              InputExample(
                  guid=guid, text_a=text_a, text_b=text_b, label=label))
    return examples

  def get_dev_examples(self, data_dir):
    """See base class."""
    examples = []
    if self.only_use_en_dev:
      lines = self._read_tsv(os.path.join(data_dir, "dev-en.tsv"))
      for i, line in enumerate(lines):
        guid = "dev-%d" % i
        text_a = self.process_text_fn(line[0])
        text_b = self.process_text_fn(line[1])
        label = self.process_text_fn(line[2])
        examples.append(
            InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
    else:
      for lang in self.supported_languages:
        lines = self._read_tsv(os.path.join(data_dir, f"dev-{lang}.tsv"))
        for i, line in enumerate(lines):
          guid = f"dev-{lang}-{i}"
          text_a = self.process_text_fn(line[0])
          text_b = self.process_text_fn(line[1])
          label = self.process_text_fn(line[2])
          examples.append(
              InputExample(
                  guid=guid, text_a=text_a, text_b=text_b, label=label))
    return examples

  def get_test_examples(self, data_dir):
    """See base class."""
    examples_by_lang = {}
    for lang in self.supported_languages:
      examples_by_lang[lang] = []
      lines = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv"))
      for i, line in enumerate(lines):
        guid = f"test-{lang}-{i}"
        text_a = self.process_text_fn(line[0])
        text_b = self.process_text_fn(line[1])
        label = "0"
        examples_by_lang[lang].append(
            InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
    if self.translated_data_dir is not None:
      for lang in self.supported_languages:
        if lang == "en":
          continue
        examples_by_lang[f"{lang}-en"] = []
        lines = self._read_tsv(
            os.path.join(self.translated_data_dir, "translate-test",
                         f"test-{lang}-en-translated.tsv"))
        for i, line in enumerate(lines):
          guid = f"test-{lang}-en-{i}"
          text_a = self.process_text_fn(line[2])
          text_b = self.process_text_fn(line[3])
          label = "0"
          examples_by_lang[f"{lang}-en"].append(
              InputExample(
                  guid=guid, text_a=text_a, text_b=text_b, label=label))
    return examples_by_lang

  def get_labels(self):
    """See base class."""
    return ["0", "1"]

  @staticmethod
  def get_processor_name():
    """See base class."""
    return "XTREME-PAWS-X"


class XtremeXnliProcessor(DataProcessor):
  """Processor for the XTREME XNLI data set."""
  supported_languages = [
      "ar", "bg", "de", "el", "en", "es", "fr", "hi", "ru", "sw", "th", "tr",
      "ur", "vi", "zh"
  ]

  def __init__(self,
               process_text_fn=tokenization.convert_to_unicode,
               translated_data_dir=None,
               only_use_en_dev=True):
    """See base class.

    Args:
      process_text_fn: See base class.
      translated_data_dir: If specified, will also include translated data in
        the training data.
      only_use_en_dev: If True, only use English dev data. Otherwise, use dev
        data from all languages.
    """
    super(XtremeXnliProcessor, self).__init__(process_text_fn)
    self.translated_data_dir = translated_data_dir
    self.only_use_en_dev = only_use_en_dev

  def get_train_examples(self, data_dir):
    """See base class."""
    lines = self._read_tsv(os.path.join(data_dir, "train-en.tsv"))

    examples = []
    if self.translated_data_dir is None:
      for i, line in enumerate(lines):
        guid = "train-%d" % i
        text_a = self.process_text_fn(line[0])
        text_b = self.process_text_fn(line[1])
        label = self.process_text_fn(line[2])
        if label == self.process_text_fn("contradictory"):
          label = self.process_text_fn("contradiction")
        examples.append(
            InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
    else:
      for lang in self.supported_languages:
        lines = self._read_tsv(
            os.path.join(self.translated_data_dir, "translate-train",
                         f"en-{lang}-translated.tsv"))
        for i, line in enumerate(lines):
          guid = f"train-{lang}-{i}"
          text_a = self.process_text_fn(line[2])
          text_b = self.process_text_fn(line[3])
          label = self.process_text_fn(line[4])
          if label == self.process_text_fn("contradictory"):
            label = self.process_text_fn("contradiction")
          examples.append(
              InputExample(
                  guid=guid, text_a=text_a, text_b=text_b, label=label))
    return examples

  def get_dev_examples(self, data_dir):
    """See base class."""
    examples = []
    if self.only_use_en_dev:
      lines = self._read_tsv(os.path.join(data_dir, "dev-en.tsv"))
      for i, line in enumerate(lines):
        guid = "dev-%d" % i
        text_a = self.process_text_fn(line[0])
        text_b = self.process_text_fn(line[1])
        label = self.process_text_fn(line[2])
        examples.append(
            InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
    else:
      for lang in self.supported_languages:
        lines = self._read_tsv(os.path.join(data_dir, f"dev-{lang}.tsv"))
        for i, line in enumerate(lines):
          guid = f"dev-{lang}-{i}"
          text_a = self.process_text_fn(line[0])
          text_b = self.process_text_fn(line[1])
          label = self.process_text_fn(line[2])
          if label == self.process_text_fn("contradictory"):
            label = self.process_text_fn("contradiction")
          examples.append(
              InputExample(
                  guid=guid, text_a=text_a, text_b=text_b, label=label))
    return examples

  def get_test_examples(self, data_dir):
    """See base class."""
    examples_by_lang = {}
    for lang in self.supported_languages:
      examples_by_lang[lang] = []
      lines = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv"))
      for i, line in enumerate(lines):
        guid = f"test-{lang}-{i}"
        text_a = self.process_text_fn(line[0])
        text_b = self.process_text_fn(line[1])
        label = "contradiction"
        examples_by_lang[lang].append(
            InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
    if self.translated_data_dir is not None:
      for lang in self.supported_languages:
        if lang == "en":
          continue
        examples_by_lang[f"{lang}-en"] = []
        lines = self._read_tsv(
            os.path.join(self.translated_data_dir, "translate-test",
                         f"test-{lang}-en-translated.tsv"))
        for i, line in enumerate(lines):
          guid = f"test-{lang}-en-{i}"
          text_a = self.process_text_fn(line[2])
          text_b = self.process_text_fn(line[3])
          label = "contradiction"
          examples_by_lang[f"{lang}-en"].append(
              InputExample(
                  guid=guid, text_a=text_a, text_b=text_b, label=label))
    return examples_by_lang

  def get_labels(self):
    """See base class."""
    return ["contradiction", "entailment", "neutral"]

  @staticmethod
  def get_processor_name():
    """See base class."""
    return "XTREME-XNLI"


def convert_single_example(ex_index, example, label_list, max_seq_length,
                           tokenizer):
  """Converts a single `InputExample` into a single `InputFeatures`."""
  label_map = {}
  if label_list:
    for (i, label) in enumerate(label_list):
      label_map[label] = i

  tokens_a = tokenizer.tokenize(example.text_a)
  tokens_b = None
  if example.text_b:
    tokens_b = tokenizer.tokenize(example.text_b)

  if tokens_b:
    # Modifies `tokens_a` and `tokens_b` in place so that the total
    # length is less than the specified length.
    # Account for [CLS], [SEP], [SEP] with "- 3"
    _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
  else:
    # Account for [CLS] and [SEP] with "- 2"
    if len(tokens_a) > max_seq_length - 2:
      tokens_a = tokens_a[0:(max_seq_length - 2)]

  seg_id_a = 0
  seg_id_b = 1
  seg_id_cls = 0
  seg_id_pad = 0

  # The convention in BERT is:
  # (a) For sequence pairs:
  #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
  #  type_ids: 0     0  0    0    0     0       0 0     1  1  1  1   1 1
  # (b) For single sequences:
  #  tokens:   [CLS] the dog is hairy . [SEP]
  #  type_ids: 0     0   0   0  0     0 0
  #
  # Where "type_ids" are used to indicate whether this is the first
  # sequence or the second sequence. The embedding vectors for `type=0` and
  # `type=1` were learned during pre-training and are added to the wordpiece
  # embedding vector (and position vector). This is not *strictly* necessary
  # since the [SEP] token unambiguously separates the sequences, but it makes
  # it easier for the model to learn the concept of sequences.
  #
  # For classification tasks, the first vector (corresponding to [CLS]) is
  # used as the "sentence vector". Note that this only makes sense because
  # the entire model is fine-tuned.
  tokens = []
  segment_ids = []
  tokens.append("[CLS]")
  segment_ids.append(seg_id_cls)
  for token in tokens_a:
    tokens.append(token)
    segment_ids.append(seg_id_a)
  tokens.append("[SEP]")
  segment_ids.append(seg_id_a)

  if tokens_b:
    for token in tokens_b:
      tokens.append(token)
      segment_ids.append(seg_id_b)
    tokens.append("[SEP]")
    segment_ids.append(seg_id_b)

  input_ids = tokenizer.convert_tokens_to_ids(tokens)

  # The mask has 1 for real tokens and 0 for padding tokens. Only real
  # tokens are attended to.
  input_mask = [1] * len(input_ids)

  # Zero-pad up to the sequence length.
  while len(input_ids) < max_seq_length:
    input_ids.append(0)
    input_mask.append(0)
    segment_ids.append(seg_id_pad)

  assert len(input_ids) == max_seq_length
  assert len(input_mask) == max_seq_length
  assert len(segment_ids) == max_seq_length

  label_id = label_map[example.label] if label_map else example.label
  if ex_index < 5:
    logging.info("*** Example ***")
    logging.info("guid: %s", example.guid)
    logging.info("tokens: %s",
                 " ".join([tokenization.printable_text(x) for x in tokens]))
    logging.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
    logging.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
    logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
    logging.info("label: %s (id = %s)", example.label, str(label_id))
    logging.info("weight: %s", example.weight)
    logging.info("example_id: %s", example.example_id)

  feature = InputFeatures(
      input_ids=input_ids,
      input_mask=input_mask,
      segment_ids=segment_ids,
      label_id=label_id,
      is_real_example=True,
      weight=example.weight,
      example_id=example.example_id)

  return feature
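

# End-to-end sketch for `convert_single_example`. It assumes a WordPiece
# vocabulary file exists at the (placeholder) path below;
# `tokenization.FullTokenizer` is the tokenizer this library is normally
# paired with.
#
#   tokenizer = tokenization.FullTokenizer(
#       vocab_file="/tmp/vocab.txt", do_lower_case=True)
#   example = InputExample(
#       guid="dev-0", text_a="is this jacksonville ?",
#       text_b="no it is not .", label="not_entailment")
#   feature = convert_single_example(
#       ex_index=0, example=example,
#       label_list=["entailment", "not_entailment"],
#       max_seq_length=128, tokenizer=tokenizer)
#   # feature.input_ids/input_mask/segment_ids each have length 128 and
#   # feature.label_id == 1 (the index of "not_entailment").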


class AXgProcessor(DataProcessor):
  """Processor for the AXg dataset (SuperGLUE diagnostics dataset)."""

  def get_test_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_jsonl(os.path.join(data_dir, "AX-g.jsonl")), "test")

  def get_labels(self):
    """See base class."""
    return ["entailment", "not_entailment"]

  @staticmethod
  def get_processor_name():
    """See base class."""
    return "AXg"

  def _create_examples(self, lines, set_type):
    """Creates examples for the training/dev/test sets."""
    examples = []
    for line in lines:
      guid = "%s-%s" % (set_type, self.process_text_fn(str(line["idx"])))
      text_a = self.process_text_fn(line["premise"])
      text_b = self.process_text_fn(line["hypothesis"])
      label = self.process_text_fn(line["label"])
      examples.append(
          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
    return examples


class BoolQProcessor(DefaultGLUEDataProcessor):
  """Processor for the BoolQ dataset (SuperGLUE version)."""

  def get_labels(self):
    """See base class."""
    return ["True", "False"]

  @staticmethod
  def get_processor_name():
    """See base class."""
    return "BoolQ"

  def _create_examples_tfds(self, set_type):
    """Creates examples for the training/dev/test sets."""
    dataset = tfds.load(
        "super_glue/boolq", split=set_type, try_gcs=True).as_numpy_iterator()
    examples = []
    for example in dataset:
      guid = "%s-%s" % (set_type, self.process_text_fn(str(example["idx"])))
      text_a = self.process_text_fn(example["question"])
      text_b = self.process_text_fn(example["passage"])
      label = "False"
      if set_type != "test":
        label = self.get_labels()[example["label"]]
      examples.append(
          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
    return examples


class CBProcessor(DefaultGLUEDataProcessor):
  """Processor for the CB dataset (SuperGLUE version)."""

  def get_labels(self):
    """See base class."""
    return ["entailment", "neutral", "contradiction"]

  @staticmethod
  def get_processor_name():
    """See base class."""
    return "CB"

  def _create_examples_tfds(self, set_type):
    """Creates examples for the training/dev/test sets."""
    dataset = tfds.load(
        "super_glue/cb", split=set_type, try_gcs=True).as_numpy_iterator()
    examples = []
    for example in dataset:
      guid = "%s-%s" % (set_type, self.process_text_fn(str(example["idx"])))
      text_a = self.process_text_fn(example["premise"])
      text_b = self.process_text_fn(example["hypothesis"])
      label = "entailment"
      if set_type != "test":
        label = self.get_labels()[example["label"]]
      examples.append(
          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
    return examples


class SuperGLUERTEProcessor(DefaultGLUEDataProcessor):
  """Processor for the RTE dataset (SuperGLUE version)."""

  def get_labels(self):
    """See base class."""
    # All datasets are converted to 2-class split, where for 3-class datasets
    # we collapse neutral and contradiction into not_entailment.
    return ["entailment", "not_entailment"]

  @staticmethod
  def get_processor_name():
    """See base class."""
    return "RTESuperGLUE"

  def _create_examples_tfds(self, set_type):
    """Creates examples for the training/dev/test sets."""
    examples = []
    dataset = tfds.load(
        "super_glue/rte", split=set_type, try_gcs=True).as_numpy_iterator()
    for example in dataset:
      guid = "%s-%s" % (set_type, self.process_text_fn(str(example["idx"])))
      text_a = self.process_text_fn(example["premise"])
      text_b = self.process_text_fn(example["hypothesis"])
      label = "entailment"
      if set_type != "test":
        label = self.get_labels()[example["label"]]
      examples.append(
          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
    return examples


class WiCInputExample(InputExample):
  """A single training/test example for the WiC task (SuperGLUE version)."""

  def __init__(self,
               guid,
               text_a,
               text_b=None,
               label=None,
               word=None,
               weight=None,
               example_id=None):
    """Constructs a WiCInputExample, an InputExample plus the marked word."""
    super(WiCInputExample, self).__init__(guid, text_a, text_b, label, weight,
                                          example_id)
    self.word = word


class WiCProcessor(DefaultGLUEDataProcessor):
  """Processor for the WiC dataset (SuperGLUE version)."""

  def get_labels(self):
    """Not used."""
    return []

  @staticmethod
  def get_processor_name():
    """See base class."""
    return "WiCSuperGLUE"

  def _create_examples_tfds(self, set_type):
    """Creates examples for the training/dev/test sets."""
    examples = []
    dataset = tfds.load(
        "super_glue/wic", split=set_type, try_gcs=True).as_numpy_iterator()
    for example in dataset:
      guid = "%s-%s" % (set_type, self.process_text_fn(str(example["idx"])))
      text_a = self.process_text_fn(example["sentence1"])
      text_b = self.process_text_fn(example["sentence2"])
      word = self.process_text_fn(example["word"])
      label = 0
      if set_type != "test":
        label = example["label"]
      examples.append(
          WiCInputExample(
              guid=guid, text_a=text_a, text_b=text_b, word=word, label=label))
    return examples

  def featurize_example(self, ex_index, example, label_list, max_seq_length,
                        tokenizer):
    """Concatenates sentence1, sentence2 and word together with [SEP] tokens."""
    del label_list
    tokens_a = tokenizer.tokenize(example.text_a)
    tokens_b = tokenizer.tokenize(example.text_b)
    tokens_word = tokenizer.tokenize(example.word)

    # Modifies `tokens_a` and `tokens_b` in place so that the total
    # length is less than the specified length.
    # Account for [CLS], [SEP], [SEP], [SEP] with "- 4"
    # Here we only pop out the first two sentence tokens.
    _truncate_seq_pair(tokens_a, tokens_b,
                       max_seq_length - 4 - len(tokens_word))

    seg_id_a = 0
    seg_id_b = 1
    seg_id_c = 2
    seg_id_cls = 0
    seg_id_pad = 0

    tokens = []
    segment_ids = []
    tokens.append("[CLS]")
    segment_ids.append(seg_id_cls)
    for token in tokens_a:
      tokens.append(token)
      segment_ids.append(seg_id_a)
    tokens.append("[SEP]")
    segment_ids.append(seg_id_a)

    for token in tokens_b:
      tokens.append(token)
      segment_ids.append(seg_id_b)

    tokens.append("[SEP]")
    segment_ids.append(seg_id_b)

    for token in tokens_word:
      tokens.append(token)
      segment_ids.append(seg_id_c)

    tokens.append("[SEP]")
    segment_ids.append(seg_id_c)

    input_ids = tokenizer.convert_tokens_to_ids(tokens)

    # The mask has 1 for real tokens and 0 for padding tokens. Only real
    # tokens are attended to.
    input_mask = [1] * len(input_ids)

    # Zero-pad up to the sequence length.
    while len(input_ids) < max_seq_length:
      input_ids.append(0)
      input_mask.append(0)
      segment_ids.append(seg_id_pad)

    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length

    label_id = example.label
    if ex_index < 5:
      logging.info("*** Example ***")
      logging.info("guid: %s", example.guid)
      logging.info("tokens: %s",
                   " ".join([tokenization.printable_text(x) for x in tokens]))
      logging.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
      logging.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
      logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
      logging.info("label: %s (id = %s)", example.label, str(label_id))
      logging.info("weight: %s", example.weight)
      logging.info("example_id: %s", example.example_id)

    feature = InputFeatures(
        input_ids=input_ids,
        input_mask=input_mask,
        segment_ids=segment_ids,
        label_id=label_id,
        is_real_example=True,
        weight=example.weight,
        example_id=example.example_id)

    return feature


def file_based_convert_examples_to_features(examples,
                                            label_list,
                                            max_seq_length,
                                            tokenizer,
                                            output_file,
                                            label_type=None,
                                            featurize_fn=None):
  """Converts a set of `InputExample`s to a TFRecord file."""

  tf.io.gfile.makedirs(os.path.dirname(output_file))
  writer = tf.io.TFRecordWriter(output_file)

  for ex_index, example in enumerate(examples):
    if ex_index % 10000 == 0:
      logging.info("Writing example %d of %d", ex_index, len(examples))

    if featurize_fn:
      feature = featurize_fn(ex_index, example, label_list, max_seq_length,
                             tokenizer)
    else:
      feature = convert_single_example(ex_index, example, label_list,
                                       max_seq_length, tokenizer)

    def create_int_feature(values):
      f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
      return f

    def create_float_feature(values):
      f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
      return f

    features = collections.OrderedDict()
    features["input_ids"] = create_int_feature(feature.input_ids)
    features["input_mask"] = create_int_feature(feature.input_mask)
    features["segment_ids"] = create_int_feature(feature.segment_ids)
    if label_type is not None and label_type == float:
      features["label_ids"] = create_float_feature([feature.label_id])
    elif feature.label_id is not None:
      features["label_ids"] = create_int_feature([feature.label_id])
    features["is_real_example"] = create_int_feature(
        [int(feature.is_real_example)])
    if feature.weight is not None:
      features["weight"] = create_float_feature([feature.weight])
    if feature.example_id is not None:
      features["example_id"] = create_int_feature([feature.example_id])
    else:
      features["example_id"] = create_int_feature([ex_index])

    tf_example = tf.train.Example(features=tf.train.Features(feature=features))
    writer.write(tf_example.SerializeToString())
  writer.close()
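

# Read-side sketch matching the schema serialized above. Feature names mirror
# the keys written by this function; `seq_length` must equal the
# `max_seq_length` used at write time, and `label_ids` would be a float
# feature for regression tasks.
#
#   def decode_record(record, seq_length=128):
#     name_to_features = {
#         "input_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
#         "input_mask": tf.io.FixedLenFeature([seq_length], tf.int64),
#         "segment_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
#         "label_ids": tf.io.FixedLenFeature([], tf.int64),
#         "is_real_example": tf.io.FixedLenFeature([], tf.int64),
#     }
#     return tf.io.parse_single_example(record, name_to_features)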


def _truncate_seq_pair(tokens_a, tokens_b, max_length):
  """Truncates a sequence pair in place to the maximum length."""

  # This is a simple heuristic which will always truncate the longer sequence
  # one token at a time. This makes more sense than truncating an equal
  # percent of tokens from each, since if one sequence is very short then
  # each token that's truncated likely contains more information than a
  # longer sequence.
  while True:
    total_length = len(tokens_a) + len(tokens_b)
    if total_length <= max_length:
      break
    if len(tokens_a) > len(tokens_b):
      tokens_a.pop()
    else:
      tokens_b.pop()
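

# Behavior sketch: the longer list is popped from the end, one token at a
# time, until the pair fits.
#
#   a, b = ["tok"] * 7, ["tok"] * 3
#   _truncate_seq_pair(a, b, max_length=8)
#   # Now len(a) == 5 and len(b) == 3, so len(a) + len(b) == 8.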


def generate_tf_record_from_data_file(processor,
                                      data_dir,
                                      tokenizer,
                                      train_data_output_path=None,
                                      eval_data_output_path=None,
                                      test_data_output_path=None,
                                      max_seq_length=128):
  """Generates and saves training data into a TFRecord file.

  Args:
    processor: Input processor object to be used for generating data. Subclass
      of `DataProcessor`.
    data_dir: Directory that contains train/eval/test data to process.
    tokenizer: The tokenizer to be applied on the data.
    train_data_output_path: Output to which the processed TFRecord for
      training will be saved.
    eval_data_output_path: Output to which the processed TFRecord for
      evaluation will be saved.
    test_data_output_path: Output to which the processed TFRecord for testing
      will be saved. Must be a pattern template with {} if processor has
      language-specific test data.
    max_seq_length: Maximum sequence length of the training/eval data to be
      generated.

  Returns:
    A dictionary containing input meta data.
  """
  assert train_data_output_path or eval_data_output_path

  label_list = processor.get_labels()
  label_type = getattr(processor, "label_type", None)
  is_regression = getattr(processor, "is_regression", False)
  has_sample_weights = getattr(processor, "weight_key", False)

  num_training_data = 0
  if train_data_output_path:
    train_input_data_examples = processor.get_train_examples(data_dir)
    file_based_convert_examples_to_features(train_input_data_examples,
                                            label_list, max_seq_length,
                                            tokenizer, train_data_output_path,
                                            label_type,
                                            processor.featurize_example)
    num_training_data = len(train_input_data_examples)

  if eval_data_output_path:
    eval_input_data_examples = processor.get_dev_examples(data_dir)
    file_based_convert_examples_to_features(eval_input_data_examples,
                                            label_list, max_seq_length,
                                            tokenizer, eval_data_output_path,
                                            label_type,
                                            processor.featurize_example)

  meta_data = {
      "processor_type": processor.get_processor_name(),
      "train_data_size": num_training_data,
      "max_seq_length": max_seq_length,
  }

  if test_data_output_path:
    test_input_data_examples = processor.get_test_examples(data_dir)
    if isinstance(test_input_data_examples, dict):
      for language, examples in test_input_data_examples.items():
        file_based_convert_examples_to_features(
            examples, label_list, max_seq_length, tokenizer,
            test_data_output_path.format(language), label_type,
            processor.featurize_example)
        meta_data["test_{}_data_size".format(language)] = len(examples)
    else:
      file_based_convert_examples_to_features(test_input_data_examples,
                                              label_list, max_seq_length,
                                              tokenizer, test_data_output_path,
                                              label_type,
                                              processor.featurize_example)
      meta_data["test_data_size"] = len(test_input_data_examples)

  if is_regression:
    meta_data["task_type"] = "bert_regression"
    meta_data["label_type"] = {int: "int", float: "float"}[label_type]
  else:
    meta_data["task_type"] = "bert_classification"
    meta_data["num_labels"] = len(processor.get_labels())
  if has_sample_weights:
    meta_data["has_sample_weights"] = True

  if eval_data_output_path:
    meta_data["eval_data_size"] = len(eval_input_data_examples)

  return meta_data
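

# End-to-end usage sketch. The vocab path and output paths are placeholders;
# `ColaProcessor` loads its examples from TFDS, so `data_dir` is unused.
#
#   processor = ColaProcessor()
#   tokenizer = tokenization.FullTokenizer(
#       vocab_file="/tmp/vocab.txt", do_lower_case=True)
#   meta_data = generate_tf_record_from_data_file(
#       processor,
#       data_dir=None,
#       tokenizer=tokenizer,
#       train_data_output_path="/tmp/cola_train.tf_record",
#       eval_data_output_path="/tmp/cola_eval.tf_record",
#       max_seq_length=128)
#   # meta_data includes processor_type, train_data_size, eval_data_size,
#   # max_seq_length, task_type and num_labels.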