Pradeep Kumar committed
Commit b1cf90a · verified · 1 parent: 535cfb6

Delete dual_encoder_dataloader_test.py

Files changed (1): dual_encoder_dataloader_test.py (+0, -131)
dual_encoder_dataloader_test.py DELETED
@@ -1,131 +0,0 @@
- # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- """Tests for official.nlp.data.dual_encoder_dataloader."""
- import os
-
- from absl.testing import parameterized
- import tensorflow as tf, tf_keras
-
- from official.nlp.data import dual_encoder_dataloader
-
-
- _LEFT_FEATURE_NAME = 'left_input'
- _RIGHT_FEATURE_NAME = 'right_input'
-
-
- def _create_fake_dataset(output_path):
-   """Creates a fake dataset containing examples for training a dual encoder model.
-
-   The created dataset contains examples with two BytesList features keyed by
-   _LEFT_FEATURE_NAME and _RIGHT_FEATURE_NAME.
-
-   Args:
-     output_path: The output path of the fake dataset.
-   """
-   def create_str_feature(values):
-     return tf.train.Feature(bytes_list=tf.train.BytesList(value=values))
-
-   with tf.io.TFRecordWriter(output_path) as writer:
-     for _ in range(100):
-       features = {}
-       features[_LEFT_FEATURE_NAME] = create_str_feature([b'hello world.'])
-       features[_RIGHT_FEATURE_NAME] = create_str_feature([b'world hello.'])
-
-       tf_example = tf.train.Example(
-           features=tf.train.Features(feature=features))
-       writer.write(tf_example.SerializeToString())
-
-
- def _make_vocab_file(vocab, output_path):
-   with tf.io.gfile.GFile(output_path, 'w') as f:
-     f.write('\n'.join(vocab + ['']))
-
-
- class DualEncoderDataTest(tf.test.TestCase, parameterized.TestCase):
-
-   def test_load_dataset(self):
-     seq_length = 16
-     batch_size = 10
-     train_data_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
-     vocab_path = os.path.join(self.get_temp_dir(), 'vocab.txt')
-
-     _create_fake_dataset(train_data_path)
-     _make_vocab_file(
-         ['[PAD]', '[UNK]', '[CLS]', '[SEP]', 'he', '##llo', 'world'], vocab_path)
-
-     data_config = dual_encoder_dataloader.DualEncoderDataConfig(
-         input_path=train_data_path,
-         seq_length=seq_length,
-         vocab_file=vocab_path,
-         lower_case=True,
-         left_text_fields=(_LEFT_FEATURE_NAME,),
-         right_text_fields=(_RIGHT_FEATURE_NAME,),
-         global_batch_size=batch_size)
-     dataset = dual_encoder_dataloader.DualEncoderDataLoader(
-         data_config).load()
-     features = next(iter(dataset))
-     self.assertCountEqual(
-         ['left_word_ids', 'left_mask', 'left_type_ids', 'right_word_ids',
-          'right_mask', 'right_type_ids'],
-         features.keys())
-     self.assertEqual(features['left_word_ids'].shape, (batch_size, seq_length))
-     self.assertEqual(features['left_mask'].shape, (batch_size, seq_length))
-     self.assertEqual(features['left_type_ids'].shape, (batch_size, seq_length))
-     self.assertEqual(features['right_word_ids'].shape, (batch_size, seq_length))
-     self.assertEqual(features['right_mask'].shape, (batch_size, seq_length))
-     self.assertEqual(features['right_type_ids'].shape, (batch_size, seq_length))
-
-   @parameterized.parameters(False, True)
-   def test_load_tfds(self, use_preprocessing_hub):
-     seq_length = 16
-     batch_size = 10
-     if use_preprocessing_hub:
-       vocab_path = ''
-       preprocessing_hub = (
-           'https://tfhub.dev/tensorflow/bert_multi_cased_preprocess/3')
-     else:
-       vocab_path = os.path.join(self.get_temp_dir(), 'vocab.txt')
-       _make_vocab_file(
-           ['[PAD]', '[UNK]', '[CLS]', '[SEP]', 'he', '##llo', 'world'],
-           vocab_path)
-       preprocessing_hub = ''
-
-     data_config = dual_encoder_dataloader.DualEncoderDataConfig(
-         tfds_name='para_crawl/enmt',
-         tfds_split='train',
-         seq_length=seq_length,
-         vocab_file=vocab_path,
-         lower_case=True,
-         left_text_fields=('en',),
-         right_text_fields=('mt',),
-         preprocessing_hub_module_url=preprocessing_hub,
-         global_batch_size=batch_size)
-     dataset = dual_encoder_dataloader.DualEncoderDataLoader(
-         data_config).load()
-     features = next(iter(dataset))
-     self.assertCountEqual(
-         ['left_word_ids', 'left_mask', 'left_type_ids', 'right_word_ids',
-          'right_mask', 'right_type_ids'],
-         features.keys())
-     self.assertEqual(features['left_word_ids'].shape, (batch_size, seq_length))
-     self.assertEqual(features['left_mask'].shape, (batch_size, seq_length))
-     self.assertEqual(features['left_type_ids'].shape, (batch_size, seq_length))
-     self.assertEqual(features['right_word_ids'].shape, (batch_size, seq_length))
-     self.assertEqual(features['right_mask'].shape, (batch_size, seq_length))
-     self.assertEqual(features['right_type_ids'].shape, (batch_size, seq_length))
-
-
- if __name__ == '__main__':
-   tf.test.main()
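
For reference, the deleted test exercised official.nlp.data.dual_encoder_dataloader against a fake TFRecord fixture. Below is a minimal, self-contained sketch of that fixture pattern, writing tf.train.Example records with two BytesList features and parsing them back with tf.data; the file path and helper names here are illustrative, not part of the deleted module:

import tensorflow as tf

def write_fake_pairs(path, num_examples=100):
  # Serialize examples with the same two string features the test used.
  with tf.io.TFRecordWriter(path) as writer:
    for _ in range(num_examples):
      example = tf.train.Example(features=tf.train.Features(feature={
          'left_input': tf.train.Feature(
              bytes_list=tf.train.BytesList(value=[b'hello world.'])),
          'right_input': tf.train.Feature(
              bytes_list=tf.train.BytesList(value=[b'world hello.'])),
      }))
      writer.write(example.SerializeToString())

def read_fake_pairs(path):
  # Parse each record back into a dict of scalar string tensors.
  spec = {
      'left_input': tf.io.FixedLenFeature([], tf.string),
      'right_input': tf.io.FixedLenFeature([], tf.string),
  }
  return tf.data.TFRecordDataset(path).map(
      lambda record: tf.io.parse_single_example(record, spec))

write_fake_pairs('/tmp/train.tf_record')
for features in read_fake_pairs('/tmp/train.tf_record').take(1):
  print(features['left_input'].numpy())  # b'hello world.'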