Pradeep Kumar committed on
Commit a39862f · verified · 1 Parent(s): 4061739

Delete dual_encoder_dataloader.py

Files changed (1)
  1. dual_encoder_dataloader.py +0 -147
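
For context, the loader deleted below parsed serialized tf.Example records containing one string feature per configured text field. The sketch below shows how such a record could be written; it is not part of the commit, and the feature names simply mirror the config defaults ('left_input', 'right_input') while the sample strings and output path are placeholders.

import tensorflow as tf

# Hypothetical input record matching the deleted loader's default field names.
example = tf.train.Example(features=tf.train.Features(feature={
    'left_input': tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[b'how do dual encoders work?'])),
    'right_input': tf.train.Feature(
        bytes_list=tf.train.BytesList(
            value=[b'They encode queries and candidates with two towers.'])),
}))
with tf.io.TFRecordWriter('/tmp/retrieval_train.tfrecord') as writer:  # placeholder path
  writer.write(example.SerializeToString())
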
dual_encoder_dataloader.py DELETED
@@ -1,147 +0,0 @@
- # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- """Loads dataset for the dual encoder (retrieval) task."""
- import dataclasses
- import functools
- import itertools
- from typing import Iterable, Mapping, Optional, Tuple
-
- import tensorflow as tf, tf_keras
- import tensorflow_hub as hub
-
- from official.common import dataset_fn
- from official.core import config_definitions as cfg
- from official.core import input_reader
- from official.nlp.data import data_loader
- from official.nlp.data import data_loader_factory
- from official.nlp.modeling import layers
-
-
- @dataclasses.dataclass
- class DualEncoderDataConfig(cfg.DataConfig):
-   """Data config for dual encoder task (tasks/dual_encoder)."""
-   # Either set `input_path`...
-   input_path: str = ''
-   # ...or `tfds_name` and `tfds_split` to specify input.
-   tfds_name: str = ''
-   tfds_split: str = ''
-   global_batch_size: int = 32
-   # Either build preprocessing with Python code by specifying these values...
-   vocab_file: str = ''
-   lower_case: bool = True
-   # ...or load preprocessing from a SavedModel at this location.
-   preprocessing_hub_module_url: str = ''
-
-   left_text_fields: Tuple[str] = ('left_input',)
-   right_text_fields: Tuple[str] = ('right_input',)
-   is_training: bool = True
-   seq_length: int = 128
-   file_type: str = 'tfrecord'
-
-
- @data_loader_factory.register_data_loader_cls(DualEncoderDataConfig)
- class DualEncoderDataLoader(data_loader.DataLoader):
-   """A class to load dataset for dual encoder task (tasks/dual_encoder)."""
-
-   def __init__(self, params):
-     if bool(params.tfds_name) == bool(params.input_path):
-       raise ValueError('Must specify either `tfds_name` and `tfds_split` '
-                        'or `input_path`.')
-     if bool(params.vocab_file) == bool(params.preprocessing_hub_module_url):
-       raise ValueError('Must specify exactly one of vocab_file (with matching '
-                        'lower_case flag) or preprocessing_hub_module_url.')
-     self._params = params
-     self._seq_length = params.seq_length
-     self._left_text_fields = params.left_text_fields
-     self._right_text_fields = params.right_text_fields
-
-     if params.preprocessing_hub_module_url:
-       preprocessing_hub_module = hub.load(params.preprocessing_hub_module_url)
-       self._tokenizer = preprocessing_hub_module.tokenize
-       self._pack_inputs = functools.partial(
-           preprocessing_hub_module.bert_pack_inputs,
-           seq_length=params.seq_length)
-     else:
-       self._tokenizer = layers.BertTokenizer(
-           vocab_file=params.vocab_file, lower_case=params.lower_case)
-       self._pack_inputs = layers.BertPackInputs(
-           seq_length=params.seq_length,
-           special_tokens_dict=self._tokenizer.get_special_tokens_dict())
-
-   def _decode(self, record: tf.Tensor):
-     """Decodes a serialized tf.Example."""
-     name_to_features = {
-         x: tf.io.FixedLenFeature([], tf.string)
-         for x in itertools.chain(
-             *[self._left_text_fields, self._right_text_fields])
-     }
-     example = tf.io.parse_single_example(record, name_to_features)
-
-     # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
-     # So cast all int64 to int32.
-     for name in example:
-       t = example[name]
-       if t.dtype == tf.int64:
-         t = tf.cast(t, tf.int32)
-       example[name] = t
-
-     return example
-
-   def _bert_tokenize(
-       self, record: Mapping[str, tf.Tensor],
-       text_fields: Iterable[str]) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
-     """Tokenize the input in text_fields using BERT tokenizer.
-
-     Args:
-       record: A tfexample record contains the features.
-       text_fields: A list of fields to be tokenzied.
-
-     Returns:
-       The tokenized features in a tuple of (input_word_ids, input_mask,
-       input_type_ids).
-     """
-     segments_text = [record[x] for x in text_fields]
-     segments_tokens = [self._tokenizer(s) for s in segments_text]
-     segments = [tf.cast(x.merge_dims(1, 2), tf.int32) for x in segments_tokens]
-     return self._pack_inputs(segments)
-
-   def _bert_preprocess(
-       self, record: Mapping[str, tf.Tensor]) -> Mapping[str, tf.Tensor]:
-     """Perform the bert word piece tokenization for left and right inputs."""
-
-     def _switch_prefix(string, old, new):
-       if string.startswith(old): return new + string[len(old):]
-       raise ValueError('Expected {} to start with {}'.format(string, old))
-
-     def _switch_key_prefix(d, old, new):
-       return {_switch_prefix(key, old, new): value for key, value in d.items()}  # pytype: disable=attribute-error  # trace-all-classes
-
-     model_inputs = _switch_key_prefix(
-         self._bert_tokenize(record, self._left_text_fields),
-         'input_', 'left_')
-     model_inputs.update(_switch_key_prefix(
-         self._bert_tokenize(record, self._right_text_fields),
-         'input_', 'right_'))
-     return model_inputs
-
-   def load(self, input_context: Optional[tf.distribute.InputContext] = None):
-     """Returns a tf.dataset.Dataset."""
-     reader = input_reader.InputReader(
-         params=self._params,
-         # Skip `decoder_fn` for tfds input.
-         decoder_fn=self._decode if self._params.input_path else None,
-         dataset_fn=dataset_fn.pick_dataset_fn(self._params.file_type),
-         postprocess_fn=self._bert_preprocess)
-     return reader.read(input_context)
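
For reference, a minimal sketch of how this module was typically wired up before the deletion. It assumes the usual Model Garden factory entry point (data_loader_factory.get_data_loader) still resolves the registered config class; the TFRecord and vocab paths are placeholders, and after this commit the dual_encoder_dataloader import would of course no longer resolve.

from official.nlp.data import data_loader_factory
from official.nlp.data import dual_encoder_dataloader

# Assumed file-based input with Python-side BERT preprocessing (vocab_file path
# is a placeholder); a TF-Hub preprocessing module could be used instead via
# preprocessing_hub_module_url.
config = dual_encoder_dataloader.DualEncoderDataConfig(
    input_path='/tmp/retrieval_train.tfrecord',  # placeholder TFRecord path
    vocab_file='/tmp/bert_vocab.txt',            # placeholder WordPiece vocab
    lower_case=True,
    left_text_fields=('left_input',),
    right_text_fields=('right_input',),
    seq_length=128,
    global_batch_size=32,
    is_training=True)

# The register_data_loader_cls decorator above lets the factory map this config
# to DualEncoderDataLoader; constructing the class directly also works.
loader = data_loader_factory.get_data_loader(config)
dataset = loader.load()  # tf.data.Dataset yielding left_*/right_* BERT inputs
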