Spaces:
Sleeping
Sleeping
Delete sentence_retrieval_lib.py
Browse files- sentence_retrieval_lib.py +0 -166
sentence_retrieval_lib.py
DELETED
@@ -1,166 +0,0 @@
|
|
1 |
-
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
|
2 |
-
#
|
3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
-
# you may not use this file except in compliance with the License.
|
5 |
-
# You may obtain a copy of the License at
|
6 |
-
#
|
7 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
-
#
|
9 |
-
# Unless required by applicable law or agreed to in writing, software
|
10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
-
# See the License for the specific language governing permissions and
|
13 |
-
# limitations under the License.
|
14 |
-
|
15 |
-
"""BERT library to process data for cross lingual sentence retrieval task."""
|
16 |
-
|
17 |
-
import os
|
18 |
-
|
19 |
-
from absl import logging
|
20 |
-
from official.nlp.data import classifier_data_lib
|
21 |
-
from official.nlp.tools import tokenization
|
22 |
-
|
23 |
-
|
24 |
-
class BuccProcessor(classifier_data_lib.DataProcessor):
|
25 |
-
"""Procssor for Xtreme BUCC data set."""
|
26 |
-
supported_languages = ["de", "fr", "ru", "zh"]
|
27 |
-
|
28 |
-
def __init__(self, process_text_fn=tokenization.convert_to_unicode):
|
29 |
-
super(BuccProcessor, self).__init__(process_text_fn)
|
30 |
-
self.languages = BuccProcessor.supported_languages
|
31 |
-
|
32 |
-
def get_dev_examples(self, data_dir, file_pattern):
|
33 |
-
return self._create_examples(
|
34 |
-
self._read_tsv(os.path.join(data_dir, file_pattern.format("dev"))),
|
35 |
-
"sample")
|
36 |
-
|
37 |
-
def get_test_examples(self, data_dir, file_pattern):
|
38 |
-
return self._create_examples(
|
39 |
-
self._read_tsv(os.path.join(data_dir, file_pattern.format("test"))),
|
40 |
-
"test")
|
41 |
-
|
42 |
-
@staticmethod
|
43 |
-
def get_processor_name():
|
44 |
-
"""See base class."""
|
45 |
-
return "BUCC"
|
46 |
-
|
47 |
-
def _create_examples(self, lines, set_type):
|
48 |
-
"""Creates examples for the training and dev sets."""
|
49 |
-
examples = []
|
50 |
-
for (i, line) in enumerate(lines):
|
51 |
-
guid = "%s-%s" % (set_type, i)
|
52 |
-
example_id = int(line[0].split("-")[1])
|
53 |
-
text_a = self.process_text_fn(line[1])
|
54 |
-
examples.append(
|
55 |
-
classifier_data_lib.InputExample(
|
56 |
-
guid=guid, text_a=text_a, example_id=example_id))
|
57 |
-
return examples
|
58 |
-
|
59 |
-
|
60 |
-
class TatoebaProcessor(classifier_data_lib.DataProcessor):
|
61 |
-
"""Procssor for Xtreme Tatoeba data set."""
|
62 |
-
supported_languages = [
|
63 |
-
"af", "ar", "bg", "bn", "de", "el", "es", "et", "eu", "fa", "fi", "fr",
|
64 |
-
"he", "hi", "hu", "id", "it", "ja", "jv", "ka", "kk", "ko", "ml", "mr",
|
65 |
-
"nl", "pt", "ru", "sw", "ta", "te", "th", "tl", "tr", "ur", "vi", "zh"
|
66 |
-
]
|
67 |
-
|
68 |
-
def __init__(self, process_text_fn=tokenization.convert_to_unicode):
|
69 |
-
super(TatoebaProcessor, self).__init__(process_text_fn)
|
70 |
-
self.languages = TatoebaProcessor.supported_languages
|
71 |
-
|
72 |
-
def get_test_examples(self, data_dir, file_path):
|
73 |
-
return self._create_examples(
|
74 |
-
self._read_tsv(os.path.join(data_dir, file_path)), "test")
|
75 |
-
|
76 |
-
@staticmethod
|
77 |
-
def get_processor_name():
|
78 |
-
"""See base class."""
|
79 |
-
return "TATOEBA"
|
80 |
-
|
81 |
-
def _create_examples(self, lines, set_type):
|
82 |
-
"""Creates examples for the training and dev sets."""
|
83 |
-
examples = []
|
84 |
-
for (i, line) in enumerate(lines):
|
85 |
-
guid = "%s-%s" % (set_type, i)
|
86 |
-
text_a = self.process_text_fn(line[0])
|
87 |
-
examples.append(
|
88 |
-
classifier_data_lib.InputExample(
|
89 |
-
guid=guid, text_a=text_a, example_id=i))
|
90 |
-
return examples
|
91 |
-
|
92 |
-
|
93 |
-
def generate_sentence_retrevial_tf_record(processor,
|
94 |
-
data_dir,
|
95 |
-
tokenizer,
|
96 |
-
eval_data_output_path=None,
|
97 |
-
test_data_output_path=None,
|
98 |
-
max_seq_length=128):
|
99 |
-
"""Generates the tf records for retrieval tasks.
|
100 |
-
|
101 |
-
Args:
|
102 |
-
processor: Input processor object to be used for generating data. Subclass
|
103 |
-
of `DataProcessor`.
|
104 |
-
data_dir: Directory that contains train/eval data to process. Data files
|
105 |
-
should be in from.
|
106 |
-
tokenizer: The tokenizer to be applied on the data.
|
107 |
-
eval_data_output_path: Output to which processed tf record for evaluation
|
108 |
-
will be saved.
|
109 |
-
test_data_output_path: Output to which processed tf record for testing
|
110 |
-
will be saved. Must be a pattern template with {} if processor has
|
111 |
-
language specific test data.
|
112 |
-
max_seq_length: Maximum sequence length of the to be generated
|
113 |
-
training/eval data.
|
114 |
-
|
115 |
-
Returns:
|
116 |
-
A dictionary containing input meta data.
|
117 |
-
"""
|
118 |
-
assert eval_data_output_path or test_data_output_path
|
119 |
-
|
120 |
-
if processor.get_processor_name() == "BUCC":
|
121 |
-
path_pattern = "{}-en.{{}}.{}"
|
122 |
-
|
123 |
-
if processor.get_processor_name() == "TATOEBA":
|
124 |
-
path_pattern = "{}-en.{}"
|
125 |
-
|
126 |
-
meta_data = {
|
127 |
-
"processor_type": processor.get_processor_name(),
|
128 |
-
"max_seq_length": max_seq_length,
|
129 |
-
"number_eval_data": {},
|
130 |
-
"number_test_data": {},
|
131 |
-
}
|
132 |
-
logging.info("Start to process %s task data", processor.get_processor_name())
|
133 |
-
|
134 |
-
for lang_a in processor.languages:
|
135 |
-
for lang_b in [lang_a, "en"]:
|
136 |
-
if eval_data_output_path:
|
137 |
-
eval_input_data_examples = processor.get_dev_examples(
|
138 |
-
data_dir, os.path.join(path_pattern.format(lang_a, lang_b)))
|
139 |
-
|
140 |
-
num_eval_data = len(eval_input_data_examples)
|
141 |
-
logging.info("Processing %d dev examples of %s-en.%s", num_eval_data,
|
142 |
-
lang_a, lang_b)
|
143 |
-
output_file = os.path.join(
|
144 |
-
eval_data_output_path,
|
145 |
-
"{}-en-{}.{}.tfrecords".format(lang_a, lang_b, "dev"))
|
146 |
-
classifier_data_lib.file_based_convert_examples_to_features(
|
147 |
-
eval_input_data_examples, None, max_seq_length, tokenizer,
|
148 |
-
output_file, None)
|
149 |
-
meta_data["number_eval_data"][f"{lang_a}-en.{lang_b}"] = num_eval_data
|
150 |
-
|
151 |
-
if test_data_output_path:
|
152 |
-
test_input_data_examples = processor.get_test_examples(
|
153 |
-
data_dir, os.path.join(path_pattern.format(lang_a, lang_b)))
|
154 |
-
|
155 |
-
num_test_data = len(test_input_data_examples)
|
156 |
-
logging.info("Processing %d test examples of %s-en.%s", num_test_data,
|
157 |
-
lang_a, lang_b)
|
158 |
-
output_file = os.path.join(
|
159 |
-
test_data_output_path,
|
160 |
-
"{}-en-{}.{}.tfrecords".format(lang_a, lang_b, "test"))
|
161 |
-
classifier_data_lib.file_based_convert_examples_to_features(
|
162 |
-
test_input_data_examples, None, max_seq_length, tokenizer,
|
163 |
-
output_file, None)
|
164 |
-
meta_data["number_test_data"][f"{lang_a}-en.{lang_b}"] = num_test_data
|
165 |
-
|
166 |
-
return meta_data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|