File size: 6,413 Bytes
5672777
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""BERT library to process data for cross lingual sentence retrieval task."""

import os

from absl import logging
from official.nlp.data import classifier_data_lib
from official.nlp.tools import tokenization


class BuccProcessor(classifier_data_lib.DataProcessor):
  """Processor for the XTREME BUCC data set.

  Each input row is expected to carry an id of the form `<prefix>-<number>` in
  its first column and the sentence text in its second column.
  """
  supported_languages = ["de", "fr", "ru", "zh"]

  def __init__(self, process_text_fn=tokenization.convert_to_unicode):
    """Initializes the processor.

    Args:
      process_text_fn: Callable applied to every raw sentence before it is
        stored in an example.
    """
    super().__init__(process_text_fn)
    self.languages = BuccProcessor.supported_languages

  def get_dev_examples(self, data_dir, file_pattern):
    """Builds the dev examples; `file_pattern` is formatted with "dev"."""
    dev_file = os.path.join(data_dir, file_pattern.format("dev"))
    # Dev examples are tagged with the "sample" guid prefix, not "dev".
    return self._create_examples(self._read_tsv(dev_file), "sample")

  def get_test_examples(self, data_dir, file_pattern):
    """Builds the test examples; `file_pattern` is formatted with "test"."""
    test_file = os.path.join(data_dir, file_pattern.format("test"))
    return self._create_examples(self._read_tsv(test_file), "test")

  @staticmethod
  def get_processor_name():
    """See base class."""
    return "BUCC"

  def _create_examples(self, lines, set_type):
    """Converts rows into `InputExample`s.

    Args:
      lines: Iterable of rows; `row[0]` is the sentence id and `row[1]` the
        sentence text.
      set_type: Prefix used to build each example's guid.

    Returns:
      A list with one `InputExample` per row, in input order.
    """

    def _to_example(index, row):
      guid = f"{set_type}-{index}"
      # The numeric part after the first "-" of the id column is the example
      # id.
      example_id = int(row[0].split("-")[1])
      text_a = self.process_text_fn(row[1])
      return classifier_data_lib.InputExample(
          guid=guid, text_a=text_a, example_id=example_id)

    return [_to_example(index, row) for index, row in enumerate(lines)]


class TatoebaProcessor(classifier_data_lib.DataProcessor):
  """Processor for the XTREME Tatoeba data set.

  Each input row is expected to carry the sentence text in its first column;
  the row position is used as the example id.
  """
  supported_languages = [
      "af", "ar", "bg", "bn", "de", "el", "es", "et", "eu", "fa", "fi", "fr",
      "he", "hi", "hu", "id", "it", "ja", "jv", "ka", "kk", "ko", "ml", "mr",
      "nl", "pt", "ru", "sw", "ta", "te", "th", "tl", "tr", "ur", "vi", "zh"
  ]

  def __init__(self, process_text_fn=tokenization.convert_to_unicode):
    """Initializes the processor.

    Args:
      process_text_fn: Callable applied to every raw sentence before it is
        stored in an example.
    """
    super().__init__(process_text_fn)
    self.languages = TatoebaProcessor.supported_languages

  def get_test_examples(self, data_dir, file_path):
    """Builds the test examples from `file_path` inside `data_dir`."""
    test_file = os.path.join(data_dir, file_path)
    return self._create_examples(self._read_tsv(test_file), "test")

  @staticmethod
  def get_processor_name():
    """See base class."""
    return "TATOEBA"

  def _create_examples(self, lines, set_type):
    """Converts rows into `InputExample`s.

    Args:
      lines: Iterable of rows; `row[0]` is the sentence text.
      set_type: Prefix used to build each example's guid.

    Returns:
      A list with one `InputExample` per row, in input order.
    """

    def _to_example(index, row):
      guid = f"{set_type}-{index}"
      text_a = self.process_text_fn(row[0])
      return classifier_data_lib.InputExample(
          guid=guid, text_a=text_a, example_id=index)

    return [_to_example(index, row) for index, row in enumerate(lines)]


def generate_sentence_retrevial_tf_record(processor,
                                          data_dir,
                                          tokenizer,
                                          eval_data_output_path=None,
                                          test_data_output_path=None,
                                          max_seq_length=128):
  """Generates the tf records for retrieval tasks.

  For every language in `processor.languages`, two inputs are converted: the
  sentences in that language and the paired English sentences. Input file
  names are derived from the processor name: `<lang>-en.<split>.<lang|en>` for
  "BUCC" and `<lang>-en.<lang|en>` for "TATOEBA".

  Args:
    processor: Input processor object to be used for generating data. Subclass
      of `DataProcessor` whose `get_processor_name()` returns "BUCC" or
      "TATOEBA".
    data_dir: Directory that contains the eval/test data files to process.
    tokenizer: The tokenizer to be applied on the data.
    eval_data_output_path: Directory into which the processed tf records for
      evaluation are written, one `<lang>-en-<lang|en>.dev.tfrecords` file per
      input. Skipped when empty or None.
    test_data_output_path: Directory into which the processed tf records for
      testing are written, one `<lang>-en-<lang|en>.test.tfrecords` file per
      input. Skipped when empty or None.
    max_seq_length: Maximum sequence length of the to be generated eval/test
      data.

  Returns:
    A dictionary containing input meta data: the processor name, the maximum
    sequence length, and the number of eval/test examples per input, keyed by
    `<lang>-en.<lang|en>`.

  Raises:
    ValueError: If the processor name is neither "BUCC" nor "TATOEBA".
  """
  assert eval_data_output_path or test_data_output_path, (
      "At least one of eval_data_output_path or test_data_output_path must "
      "be provided.")

  processor_name = processor.get_processor_name()
  if processor_name == "BUCC":
    # The doubled braces survive the first `format` call below, leaving a `{}`
    # slot that the processor later fills with the split name.
    path_pattern = "{}-en.{{}}.{}"
  elif processor_name == "TATOEBA":
    path_pattern = "{}-en.{}"
  else:
    # Fail early with a clear message instead of hitting an unbound
    # `path_pattern` inside the loop.
    raise ValueError(
        "Unsupported processor %r; expected 'BUCC' or 'TATOEBA'." %
        (processor_name,))

  meta_data = {
      "processor_type": processor_name,
      "max_seq_length": max_seq_length,
      "number_eval_data": {},
      "number_test_data": {},
  }
  logging.info("Start to process %s task data", processor_name)

  for lang_a in processor.languages:
    for lang_b in [lang_a, "en"]:
      # Both splits of one language pair share the input name and meta key.
      input_file = path_pattern.format(lang_a, lang_b)
      pair_key = f"{lang_a}-en.{lang_b}"

      if eval_data_output_path:
        # NOTE(review): TatoebaProcessor in this file defines no
        # `get_dev_examples`; confirm the base-class behavior before
        # requesting eval output for Tatoeba.
        eval_input_data_examples = processor.get_dev_examples(
            data_dir, input_file)

        num_eval_data = len(eval_input_data_examples)
        logging.info("Processing %d dev examples of %s-en.%s", num_eval_data,
                     lang_a, lang_b)
        output_file = os.path.join(
            eval_data_output_path,
            "{}-en-{}.{}.tfrecords".format(lang_a, lang_b, "dev"))
        classifier_data_lib.file_based_convert_examples_to_features(
            eval_input_data_examples, None, max_seq_length, tokenizer,
            output_file, None)
        meta_data["number_eval_data"][pair_key] = num_eval_data

      if test_data_output_path:
        test_input_data_examples = processor.get_test_examples(
            data_dir, input_file)

        num_test_data = len(test_input_data_examples)
        logging.info("Processing %d test examples of %s-en.%s", num_test_data,
                     lang_a, lang_b)
        output_file = os.path.join(
            test_data_output_path,
            "{}-en-{}.{}.tfrecords".format(lang_a, lang_b, "test"))
        classifier_data_lib.file_based_convert_examples_to_features(
            test_input_data_examples, None, max_seq_length, tokenizer,
            output_file, None)
        meta_data["number_test_data"][pair_key] = num_test_data

  return meta_data