Spaces:
Running
Running
Pradeep Kumar
committed on
Delete sentence_prediction_dataloader.py
Browse files
sentence_prediction_dataloader.py
DELETED
@@ -1,267 +0,0 @@
|
|
1 |
-
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
|
2 |
-
#
|
3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
-
# you may not use this file except in compliance with the License.
|
5 |
-
# You may obtain a copy of the License at
|
6 |
-
#
|
7 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
-
#
|
9 |
-
# Unless required by applicable law or agreed to in writing, software
|
10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
-
# See the License for the specific language governing permissions and
|
13 |
-
# limitations under the License.
|
14 |
-
|
15 |
-
"""Loads dataset for the sentence prediction (classification) task."""
|
16 |
-
import dataclasses
|
17 |
-
import functools
|
18 |
-
from typing import List, Mapping, Optional, Tuple
|
19 |
-
|
20 |
-
import tensorflow as tf, tf_keras
|
21 |
-
import tensorflow_hub as hub
|
22 |
-
|
23 |
-
from official.common import dataset_fn
|
24 |
-
from official.core import config_definitions as cfg
|
25 |
-
from official.core import input_reader
|
26 |
-
from official.nlp import modeling
|
27 |
-
from official.nlp.data import data_loader
|
28 |
-
from official.nlp.data import data_loader_factory
|
29 |
-
|
30 |
-
# Maps the `label_type` config string to the TensorFlow dtype that the data
# loaders in this file use when declaring the label feature to parse.
LABEL_TYPES_MAP = {'int': tf.int64, 'float': tf.float32}
|
31 |
-
|
32 |
-
|
33 |
-
@dataclasses.dataclass
class SentencePredictionDataConfig(cfg.DataConfig):
  """Data config for sentence prediction task (tasks/sentence_prediction).

  This config is registered for `SentencePredictionDataLoader`, which reads
  `seq_length`, `label_type`, `label_field`, `label_name`,
  `include_example_id` and `file_type` from it.
  """
  input_path: str = ''
  global_batch_size: int = 32
  is_training: bool = True
  # Fixed length of the `input_ids`, `input_mask` and `segment_ids` features.
  seq_length: int = 128
  # Key into LABEL_TYPES_MAP ('int' or 'float'); picks the label dtype.
  label_type: str = 'int'
  # When True, an int64 `example_id` feature is parsed as well.
  include_example_id: bool = False
  # Key of the label feature inside the tf.Example.
  label_field: str = 'label_ids'
  # Optional (key in tf.Example, feature name) pair used to expose the label
  # under another name, e.g. ('label_ids', 'next_sentence_labels').
  label_name: Optional[Tuple[str, str]] = None
  # One of tfrecord, sstable, or recordio.
  file_type: str = 'tfrecord'
|
49 |
-
|
50 |
-
|
51 |
-
@data_loader_factory.register_data_loader_cls(SentencePredictionDataConfig)
class SentencePredictionDataLoader(data_loader.DataLoader):
  """Loads pre-tokenized examples for sentence prediction (classification)."""

  def __init__(self, params):
    """Caches the config values that the decode/parse steps rely on.

    Args:
      params: The data config; this class reads `seq_length`,
        `include_example_id`, `label_field`, `label_name`, `label_type` and
        `file_type` from it.
    """
    self._params = params
    self._seq_length = params.seq_length
    self._include_example_id = params.include_example_id
    self._label_field = params.label_field
    # `label_name` is a single (key, new name) pair; store it as a dict so
    # `_parse` can look the label field up directly.
    self._label_name_mapping = (
        dict([params.label_name]) if params.label_name else dict())

  def name_to_features_spec(self):
    """Builds the feature spec to decode. Subclasses may override to extend."""
    label_dtype = LABEL_TYPES_MAP[self._params.label_type]
    spec = {
        feature: tf.io.FixedLenFeature([self._seq_length], tf.int64)
        for feature in ('input_ids', 'input_mask', 'segment_ids')
    }
    spec[self._label_field] = tf.io.FixedLenFeature([], label_dtype)
    if self._include_example_id:
      spec['example_id'] = tf.io.FixedLenFeature([], tf.int64)
    return spec

  def _decode(self, record: tf.Tensor):
    """Parses one serialized tf.Example into a dict of tensors."""
    example = tf.io.parse_single_example(record, self.name_to_features_spec())

    # tf.Example stores integers as tf.int64 while the TPU only supports
    # tf.int32, so every int64 tensor is narrowed to int32.
    for name, tensor in list(example.items()):
      if tensor.dtype == tf.int64:
        example[name] = tf.cast(tensor, tf.int32)

    return example

  def _parse(self, record: Mapping[str, tf.Tensor]):
    """Renames decoded features to the names consumed by the model."""
    renames = {
        'input_ids': 'input_word_ids',
        'input_mask': 'input_mask',
        'segment_ids': 'input_type_ids'
    }
    # Keys without an entry in `renames` are passed through unchanged.
    features = {renames.get(key, key): record[key] for key in record}

    # Additionally expose the label under its configured alias, if any.
    if self._label_field in self._label_name_mapping:
      alias = self._label_name_mapping[self._label_field]
      features[alias] = record[self._label_field]

    return features

  def load(self, input_context: Optional[tf.distribute.InputContext] = None):
    """Returns a tf.data.Dataset produced by the configured input reader."""
    pick_dataset = dataset_fn.pick_dataset_fn(self._params.file_type)
    reader = input_reader.InputReader(
        dataset_fn=pick_dataset,
        params=self._params,
        decoder_fn=self._decode,
        parser_fn=self._parse)
    return reader.read(input_context)
|
121 |
-
|
122 |
-
|
123 |
-
@dataclasses.dataclass
class SentencePredictionTextDataConfig(cfg.DataConfig):
  """Data config for sentence prediction task with raw text.

  This config is registered for `SentencePredictionTextDataLoader`, which
  checks in its constructor that exactly one input source (`input_path`, or
  `tfds_name` together with `tfds_split`) and exactly one preprocessing source
  (`vocab_file` or `preprocessing_hub_module_url`) is set.
  """
  # Input source, option 1: set `input_path`...
  input_path: str = ''
  # Label dtype selector: either `int` or `float`.
  label_type: str = 'int'
  # Input source, option 2: ...or set both `tfds_name` and `tfds_split`.
  tfds_name: str = ''
  tfds_split: str = ''
  # Names of the text feature fields; the text features are concatenated in
  # the order given here.
  text_fields: Optional[List[str]] = None
  label_field: str = 'label'
  global_batch_size: int = 32
  seq_length: int = 128
  is_training: bool = True
  # Preprocessing, option 1: build it in Python from the values below, via
  # modeling.layers.BertTokenizer() / SentencepieceTokenizer()...
  tokenization: str = 'WordPiece'  # WordPiece or SentencePiece
  # Text vocab file when tokenization is WordPiece, or the
  # sentencepiece.ModelProto file when tokenization is SentencePiece.
  vocab_file: str = ''
  lower_case: bool = True
  # Preprocessing, option 2: ...or load it from a SavedModel at this location.
  preprocessing_hub_module_url: str = ''
  # One of tfrecord, sstable, or recordio.
  file_type: str = 'tfrecord'
  include_example_id: bool = False
|
152 |
-
|
153 |
-
|
154 |
-
class TextProcessor(tf.Module):
  """Text features processing for sentence prediction task.

  Tokenizes raw text segments and packs them into model inputs of length
  `seq_length`. The tokenizer and packer come either from a preprocessing
  SavedModel loaded with `hub.load`, or from the tokenizer and
  `BertPackInputs` layers in `modeling.layers`.
  """

  def __init__(self,
               seq_length: int,
               vocab_file: Optional[str] = None,
               tokenization: Optional[str] = None,
               lower_case: Optional[bool] = True,
               preprocessing_hub_module_url: Optional[str] = None):
    """Builds the tokenizer and the input packer.

    Args:
      seq_length: Sequence length handed to the input packer.
      vocab_file: WordPiece vocab file, or SentencePiece model file; only
        used when `preprocessing_hub_module_url` is empty.
      tokenization: 'WordPiece' or 'SentencePiece'; only used when
        `preprocessing_hub_module_url` is empty.
      lower_case: Forwarded to the tokenizer layer; only used when
        `preprocessing_hub_module_url` is empty.
      preprocessing_hub_module_url: If set, the tokenizer and packer are
        taken from the module loaded from this location and the other
        tokenizer arguments are ignored.

    Raises:
      ValueError: If no hub module is given and `tokenization` is neither
        'WordPiece' nor 'SentencePiece'.
    """
    # Run the tf.Module constructor so base-class state such as `name` and
    # `name_scope` is set up before any attribute is assigned.
    super().__init__()

    if preprocessing_hub_module_url:
      self._preprocessing_hub_module = hub.load(preprocessing_hub_module_url)
      self._tokenizer = self._preprocessing_hub_module.tokenize
      self._pack_inputs = functools.partial(
          self._preprocessing_hub_module.bert_pack_inputs,
          seq_length=seq_length)
      return

    if tokenization == 'WordPiece':
      self._tokenizer = modeling.layers.BertTokenizer(
          vocab_file=vocab_file, lower_case=lower_case)
    elif tokenization == 'SentencePiece':
      self._tokenizer = modeling.layers.SentencepieceTokenizer(
          model_file_path=vocab_file,
          lower_case=lower_case,
          strip_diacritics=True)  # Strip diacritics to follow ALBERT model
    else:
      raise ValueError('Unsupported tokenization: %s' % tokenization)

    self._pack_inputs = modeling.layers.BertPackInputs(
        seq_length=seq_length,
        special_tokens_dict=self._tokenizer.get_special_tokens_dict())

  def __call__(self, segments):
    """Tokenizes every segment and packs the results into model inputs.

    Args:
      segments: A list of text inputs, one entry per text field; each entry
        is passed to the tokenizer as is.

    Returns:
      The output of the input packer for the tokenized segments.
    """
    tokenized = [self._tokenizer(segment) for segment in segments]

    # BertTokenizer returns a RaggedTensor with shape [batch, word, subword],
    # and SentencepieceTokenizer returns a RaggedTensor with shape
    # [batch, sentencepiece]; collapse the former so both end up rank 2.
    flattened = []
    for tokens in tokenized:
      if tokens.shape.rank > 2:
        tokens = tokens.merge_dims(1, -1)
      flattened.append(tf.cast(tokens, tf.int32))

    return self._pack_inputs(flattened)
|
196 |
-
|
197 |
-
|
198 |
-
@data_loader_factory.register_data_loader_cls(SentencePredictionTextDataConfig)
class SentencePredictionTextDataLoader(data_loader.DataLoader):
  """Loads raw-text examples for the sentence prediction task."""

  def __init__(self, params):
    """Validates the config and builds the text processor.

    Args:
      params: The data config; see the checks below for the required
        combinations of fields.

    Raises:
      ValueError: If only one of `tfds_name` / `tfds_split` is set, if the
        input source is not exactly one of TFDS or `input_path`, if
        `text_fields` is empty, or if the preprocessing source is not exactly
        one of `vocab_file` or `preprocessing_hub_module_url`.
    """
    uses_tfds = bool(params.tfds_name)
    if uses_tfds != bool(params.tfds_split):
      raise ValueError('`tfds_name` and `tfds_split` should be specified or '
                       'unspecified at the same time.')
    if uses_tfds == bool(params.input_path):
      raise ValueError('Must specify either `tfds_name` and `tfds_split` '
                       'or `input_path`.')
    if not params.text_fields:
      raise ValueError('Unexpected empty text fields.')
    if bool(params.vocab_file) == bool(params.preprocessing_hub_module_url):
      raise ValueError('Must specify exactly one of vocab_file (with matching '
                       'lower_case flag) or preprocessing_hub_module_url.')

    self._params = params
    self._text_fields = params.text_fields
    self._label_field = params.label_field
    self._label_type = params.label_type
    self._include_example_id = params.include_example_id
    self._text_processor = TextProcessor(
        seq_length=params.seq_length,
        vocab_file=params.vocab_file,
        tokenization=params.tokenization,
        lower_case=params.lower_case,
        preprocessing_hub_module_url=params.preprocessing_hub_module_url)

  def _bert_preprocess(self, record: Mapping[str, tf.Tensor]):
    """Turns the text fields into model inputs and keeps the other fields."""
    text_segments = [record[field] for field in self._text_fields]
    model_inputs = self._text_processor(text_segments)
    # Non-text features (e.g. the label) are carried over under their own key.
    model_inputs.update(
        {key: record[key] for key in record if key not in self._text_fields})
    return model_inputs

  def name_to_features_spec(self):
    """Builds the feature spec: text fields, label and optional example id."""
    spec = {
        field: tf.io.FixedLenFeature([], tf.string)
        for field in self._text_fields
    }
    label_dtype = LABEL_TYPES_MAP[self._label_type]
    spec[self._label_field] = tf.io.FixedLenFeature([], label_dtype)
    if self._include_example_id:
      spec['example_id'] = tf.io.FixedLenFeature([], tf.int64)
    return spec

  def _decode(self, record: tf.Tensor):
    """Parses one serialized tf.Example into a dict of tensors."""
    example = tf.io.parse_single_example(record, self.name_to_features_spec())

    # tf.Example stores integers as tf.int64 while the TPU only supports
    # tf.int32, so every int64 tensor is narrowed to int32.
    for name, tensor in list(example.items()):
      if tensor.dtype == tf.int64:
        example[name] = tf.cast(tensor, tf.int32)

    return example

  def load(self, input_context: Optional[tf.distribute.InputContext] = None):
    """Returns a tf.data.Dataset produced by the configured input reader."""
    # A decoder is only supplied for file input; when `input_path` is empty
    # (the TFDS case, per the constructor checks) none is passed.
    decoder = self._decode if self._params.input_path else None
    reader = input_reader.InputReader(
        dataset_fn=dataset_fn.pick_dataset_fn(self._params.file_type),
        decoder_fn=decoder,
        params=self._params,
        postprocess_fn=self._bert_preprocess)
    return reader.read(input_context)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|