Pradeep Kumar committed on
Commit 31f84b4 · verified · 1 Parent(s): 3dad389

Delete sentence_prediction_dataloader.py

Files changed (1)
  1. sentence_prediction_dataloader.py +0 -267
sentence_prediction_dataloader.py DELETED
@@ -1,267 +0,0 @@
- # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- """Loads dataset for the sentence prediction (classification) task."""
- import dataclasses
- import functools
- from typing import List, Mapping, Optional, Tuple
-
- import tensorflow as tf, tf_keras
- import tensorflow_hub as hub
-
- from official.common import dataset_fn
- from official.core import config_definitions as cfg
- from official.core import input_reader
- from official.nlp import modeling
- from official.nlp.data import data_loader
- from official.nlp.data import data_loader_factory
-
- LABEL_TYPES_MAP = {'int': tf.int64, 'float': tf.float32}
-
-
- @dataclasses.dataclass
- class SentencePredictionDataConfig(cfg.DataConfig):
-   """Data config for sentence prediction task (tasks/sentence_prediction)."""
-   input_path: str = ''
-   global_batch_size: int = 32
-   is_training: bool = True
-   seq_length: int = 128
-   label_type: str = 'int'
-   # Whether to include the example id number.
-   include_example_id: bool = False
-   label_field: str = 'label_ids'
-   # Maps the key in TfExample to feature name.
-   # E.g., 'label_ids' to 'next_sentence_labels'.
-   label_name: Optional[Tuple[str, str]] = None
-   # Either tfrecord, sstable, or recordio.
-   file_type: str = 'tfrecord'
-
-
- @data_loader_factory.register_data_loader_cls(SentencePredictionDataConfig)
- class SentencePredictionDataLoader(data_loader.DataLoader):
-   """A class to load dataset for sentence prediction (classification) task."""
-
-   def __init__(self, params):
-     self._params = params
-     self._seq_length = params.seq_length
-     self._include_example_id = params.include_example_id
-     self._label_field = params.label_field
-     if params.label_name:
-       self._label_name_mapping = dict([params.label_name])
-     else:
-       self._label_name_mapping = dict()
-
-   def name_to_features_spec(self):
-     """Defines features to decode. Subclass may override to append features."""
-     label_type = LABEL_TYPES_MAP[self._params.label_type]
-     name_to_features = {
-         'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
-         'input_mask': tf.io.FixedLenFeature([self._seq_length], tf.int64),
-         'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
-         self._label_field: tf.io.FixedLenFeature([], label_type),
-     }
-     if self._include_example_id:
-       name_to_features['example_id'] = tf.io.FixedLenFeature([], tf.int64)
-
-     return name_to_features
-
-   def _decode(self, record: tf.Tensor):
-     """Decodes a serialized tf.Example."""
-     example = tf.io.parse_single_example(record, self.name_to_features_spec())
-
-     # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
-     # So cast all int64 to int32.
-     for name in example:
-       t = example[name]
-       if t.dtype == tf.int64:
-         t = tf.cast(t, tf.int32)
-       example[name] = t
-
-     return example
-
-   def _parse(self, record: Mapping[str, tf.Tensor]):
-     """Parses raw tensors into a dict of tensors to be consumed by the model."""
-     key_mapping = {
-         'input_ids': 'input_word_ids',
-         'input_mask': 'input_mask',
-         'segment_ids': 'input_type_ids'
-     }
-     ret = {}
-     for record_key in record:
-       if record_key in key_mapping:
-         ret[key_mapping[record_key]] = record[record_key]
-       else:
-         ret[record_key] = record[record_key]
-
-     if self._label_field in self._label_name_mapping:
-       ret[self._label_name_mapping[self._label_field]] = record[
-           self._label_field]
-
-     return ret
-
-   def load(self, input_context: Optional[tf.distribute.InputContext] = None):
-     """Returns a tf.data.Dataset."""
-     reader = input_reader.InputReader(
-         dataset_fn=dataset_fn.pick_dataset_fn(self._params.file_type),
-         params=self._params,
-         decoder_fn=self._decode,
-         parser_fn=self._parse)
-     return reader.read(input_context)
-
-
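For context, a minimal usage sketch of the loader being deleted, assuming the module is still importable from the Model Garden package; the TFRecord path is hypothetical:

import tensorflow as tf

from official.nlp.data import data_loader_factory
from official.nlp.data import sentence_prediction_dataloader as spd

# Hypothetical TFRecord of pre-tokenized examples with 'input_ids',
# 'input_mask', 'segment_ids' and 'label_ids' features.
config = spd.SentencePredictionDataConfig(
    input_path='/tmp/train.tf_record',
    global_batch_size=32,
    seq_length=128,
    is_training=True)

# register_data_loader_cls (see above) lets the factory resolve the config
# type to SentencePredictionDataLoader.
loader = data_loader_factory.get_data_loader(config)
dataset = loader.load()  # tf.data.Dataset yielding model-ready feature dicts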
- @dataclasses.dataclass
- class SentencePredictionTextDataConfig(cfg.DataConfig):
-   """Data config for sentence prediction task with raw text."""
-   # Either set `input_path`...
-   input_path: str = ''
-   # Either `int` or `float`.
-   label_type: str = 'int'
-   # ...or `tfds_name` and `tfds_split` to specify input.
-   tfds_name: str = ''
-   tfds_split: str = ''
-   # The name of the text feature fields. The text features will be
-   # concatenated in order.
-   text_fields: Optional[List[str]] = None
-   label_field: str = 'label'
-   global_batch_size: int = 32
-   seq_length: int = 128
-   is_training: bool = True
-   # Either build preprocessing with Python code by specifying these values
-   # for modeling.layers.BertTokenizer()/SentencepieceTokenizer()...
-   tokenization: str = 'WordPiece'  # WordPiece or SentencePiece
-   # Text vocab file if tokenization is WordPiece, or sentencepiece.ModelProto
-   # file if tokenization is SentencePiece.
-   vocab_file: str = ''
-   lower_case: bool = True
-   # ...or load preprocessing from a SavedModel at this location.
-   preprocessing_hub_module_url: str = ''
-   # Either tfrecord, sstable, or recordio.
-   file_type: str = 'tfrecord'
-   include_example_id: bool = False
-
-
- class TextProcessor(tf.Module):
-   """Text features processing for sentence prediction task."""
-
-   def __init__(self,
-                seq_length: int,
-                vocab_file: Optional[str] = None,
-                tokenization: Optional[str] = None,
-                lower_case: Optional[bool] = True,
-                preprocessing_hub_module_url: Optional[str] = None):
-     if preprocessing_hub_module_url:
-       self._preprocessing_hub_module = hub.load(preprocessing_hub_module_url)
-       self._tokenizer = self._preprocessing_hub_module.tokenize
-       self._pack_inputs = functools.partial(
-           self._preprocessing_hub_module.bert_pack_inputs,
-           seq_length=seq_length)
-       return
-
-     if tokenization == 'WordPiece':
-       self._tokenizer = modeling.layers.BertTokenizer(
-           vocab_file=vocab_file, lower_case=lower_case)
-     elif tokenization == 'SentencePiece':
-       self._tokenizer = modeling.layers.SentencepieceTokenizer(
-           model_file_path=vocab_file,
-           lower_case=lower_case,
-           strip_diacritics=True)  # Strip diacritics to follow ALBERT model.
-     else:
-       raise ValueError('Unsupported tokenization: %s' % tokenization)
-
-     self._pack_inputs = modeling.layers.BertPackInputs(
-         seq_length=seq_length,
-         special_tokens_dict=self._tokenizer.get_special_tokens_dict())
-
-   def __call__(self, segments):
-     segments = [self._tokenizer(s) for s in segments]
-     # BertTokenizer returns a RaggedTensor with shape [batch, word, subword];
-     # SentencepieceTokenizer returns a RaggedTensor with shape
-     # [batch, sentencepiece].
-     segments = [
-         tf.cast(x.merge_dims(1, -1) if x.shape.rank > 2 else x, tf.int32)
-         for x in segments
-     ]
-     return self._pack_inputs(segments)
-
-
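A sketch of TextProcessor in isolation may help: given a WordPiece vocab (the path below is hypothetical), it tokenizes each batch of strings and packs them into fixed-length BERT inputs.

processor = TextProcessor(
    seq_length=128,
    vocab_file='/tmp/vocab.txt',  # hypothetical WordPiece vocab file
    tokenization='WordPiece',
    lower_case=True)

inputs = processor([
    tf.constant(['the first segment']),
    tf.constant(['the second segment']),
])
# inputs is a dict with 'input_word_ids', 'input_mask' and
# 'input_type_ids', each of shape [1, 128].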
- @data_loader_factory.register_data_loader_cls(SentencePredictionTextDataConfig)
- class SentencePredictionTextDataLoader(data_loader.DataLoader):
-   """Loads dataset with raw text for sentence prediction task."""
-
-   def __init__(self, params):
-     if bool(params.tfds_name) != bool(params.tfds_split):
-       raise ValueError('`tfds_name` and `tfds_split` should be specified or '
-                        'unspecified at the same time.')
-     if bool(params.tfds_name) == bool(params.input_path):
-       raise ValueError('Must specify either `tfds_name` and `tfds_split` '
-                        'or `input_path`.')
-     if not params.text_fields:
-       raise ValueError('Unexpected empty text fields.')
-     if bool(params.vocab_file) == bool(params.preprocessing_hub_module_url):
-       raise ValueError('Must specify exactly one of vocab_file (with matching '
-                        'lower_case flag) or preprocessing_hub_module_url.')
-
-     self._params = params
-     self._text_fields = params.text_fields
-     self._label_field = params.label_field
-     self._label_type = params.label_type
-     self._include_example_id = params.include_example_id
-     self._text_processor = TextProcessor(
-         seq_length=params.seq_length,
-         vocab_file=params.vocab_file,
-         tokenization=params.tokenization,
-         lower_case=params.lower_case,
-         preprocessing_hub_module_url=params.preprocessing_hub_module_url)
-
-   def _bert_preprocess(self, record: Mapping[str, tf.Tensor]):
-     """Runs BERT preprocessing on the configured text fields."""
-     segments = [record[x] for x in self._text_fields]
-     model_inputs = self._text_processor(segments)
-     for key in record:
-       if key not in self._text_fields:
-         model_inputs[key] = record[key]
-     return model_inputs
-
-   def name_to_features_spec(self):
-     name_to_features = {}
-     for text_field in self._text_fields:
-       name_to_features[text_field] = tf.io.FixedLenFeature([], tf.string)
-
-     label_type = LABEL_TYPES_MAP[self._label_type]
-     name_to_features[self._label_field] = tf.io.FixedLenFeature([], label_type)
-     if self._include_example_id:
-       name_to_features['example_id'] = tf.io.FixedLenFeature([], tf.int64)
-     return name_to_features
-
-   def _decode(self, record: tf.Tensor):
-     """Decodes a serialized tf.Example."""
-     example = tf.io.parse_single_example(record, self.name_to_features_spec())
-     # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
-     # So cast all int64 to int32.
-     for name in example:
-       t = example[name]
-       if t.dtype == tf.int64:
-         t = tf.cast(t, tf.int32)
-       example[name] = t
-
-     return example
-
-   def load(self, input_context: Optional[tf.distribute.InputContext] = None):
-     """Returns a tf.data.Dataset."""
-     reader = input_reader.InputReader(
-         dataset_fn=dataset_fn.pick_dataset_fn(self._params.file_type),
-         decoder_fn=self._decode if self._params.input_path else None,
-         params=self._params,
-         postprocess_fn=self._bert_preprocess)
-     return reader.read(input_context)
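The raw-text loader follows the same factory pattern; here is a sketch reading GLUE/MRPC from TFDS, reusing the spd alias from the first sketch (dataset name, text fields, and vocab path are illustrative):

config = spd.SentencePredictionTextDataConfig(
    tfds_name='glue/mrpc',  # illustrative TFDS dataset
    tfds_split='train',
    text_fields=['sentence1', 'sentence2'],
    label_field='label',
    vocab_file='/tmp/vocab.txt',  # hypothetical WordPiece vocab file
    tokenization='WordPiece',
    global_batch_size=32,
    seq_length=128,
    is_training=True)

loader = data_loader_factory.get_data_loader(config)
dataset = loader.load()  # raw text is tokenized and packed on the fly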