Pradeep Kumar committed
Commit 7a04101 · verified · 1 parent: 0b89b7a

Delete wmt_dataloader_test.py

Files changed (1)
  1. wmt_dataloader_test.py +0 -130
wmt_dataloader_test.py DELETED
@@ -1,130 +0,0 @@
- # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- """Tests for official.nlp.data.wmt_dataloader."""
- import os
- from absl.testing import parameterized
-
- import tensorflow as tf, tf_keras
-
- from sentencepiece import SentencePieceTrainer
- from official.nlp.data import wmt_dataloader
-
-
- def _generate_line_file(filepath, lines):
-   with tf.io.gfile.GFile(filepath, 'w') as f:
-     for l in lines:
-       f.write('{}\n'.format(l))
-
-
- def _generate_record_file(filepath, src_lines, tgt_lines, unique_id=False):
-   writer = tf.io.TFRecordWriter(filepath)
-   for i, (src, tgt) in enumerate(zip(src_lines, tgt_lines)):
-     features = {
-         'en': tf.train.Feature(
-             bytes_list=tf.train.BytesList(
-                 value=[src.encode()])),
-         'reverse_en': tf.train.Feature(
-             bytes_list=tf.train.BytesList(
-                 value=[tgt.encode()])),
-     }
-     if unique_id:
-       features['unique_id'] = tf.train.Feature(
-           int64_list=tf.train.Int64List(value=[i]))
-     example = tf.train.Example(
-         features=tf.train.Features(
-             feature=features))
-     writer.write(example.SerializeToString())
-   writer.close()
-
-
- def _train_sentencepiece(input_path, vocab_size, model_path, eos_id=1):
-   argstr = ' '.join([
-       f'--input={input_path}', f'--vocab_size={vocab_size}',
-       '--character_coverage=0.995',
-       f'--model_prefix={model_path}', '--model_type=bpe',
-       '--bos_id=-1', '--pad_id=0', f'--eos_id={eos_id}', '--unk_id=2'
-   ])
-   SentencePieceTrainer.Train(argstr)
-
-
- class WMTDataLoaderTest(tf.test.TestCase, parameterized.TestCase):
-
-   def setUp(self):
-     super(WMTDataLoaderTest, self).setUp()
-     self._temp_dir = self.get_temp_dir()
-     src_lines = [
-         'abc ede fg',
-         'bbcd ef a g',
-         'de f a a g'
-     ]
-     tgt_lines = [
-         'dd cc a ef g',
-         'bcd ef a g',
-         'gef cd ba'
-     ]
-     self._record_train_input_path = os.path.join(self._temp_dir, 'train.record')
-     _generate_record_file(self._record_train_input_path, src_lines, tgt_lines)
-     self._record_test_input_path = os.path.join(self._temp_dir, 'test.record')
-     _generate_record_file(self._record_test_input_path, src_lines, tgt_lines,
-                           unique_id=True)
-     self._sentencepeice_input_path = os.path.join(self._temp_dir, 'inputs.txt')
-     _generate_line_file(self._sentencepeice_input_path, src_lines + tgt_lines)
-     sentencepeice_model_prefix = os.path.join(self._temp_dir, 'sp')
-     _train_sentencepiece(self._sentencepeice_input_path, 20,
-                          sentencepeice_model_prefix)
-     self._sentencepeice_model_path = '{}.model'.format(
-         sentencepeice_model_prefix)
-
-   @parameterized.named_parameters(
-       ('train_static', True, True, 100, (2, 35)),
-       ('train_non_static', True, False, 100, (12, 7)),
-       ('non_train_static', False, True, 3, (3, 35)),
-       ('non_train_non_static', False, False, 50, (2, 7)),)
-   def test_load_dataset(
-       self, is_training, static_batch, batch_size, expected_shape):
-     data_config = wmt_dataloader.WMTDataConfig(
-         input_path=self._record_train_input_path
-         if is_training else self._record_test_input_path,
-         max_seq_length=35,
-         global_batch_size=batch_size,
-         is_training=is_training,
-         static_batch=static_batch,
-         src_lang='en',
-         tgt_lang='reverse_en',
-         sentencepiece_model_path=self._sentencepeice_model_path)
-     dataset = wmt_dataloader.WMTDataLoader(data_config).load()
-     examples = next(iter(dataset))
-     inputs, targets = examples['inputs'], examples['targets']
-     self.assertEqual(inputs.shape, expected_shape)
-     self.assertEqual(targets.shape, expected_shape)
-
-   def test_load_dataset_raise_invalid_window(self):
-     batch_tokens_size = 10  # this is too small to form buckets.
-     data_config = wmt_dataloader.WMTDataConfig(
-         input_path=self._record_train_input_path,
-         max_seq_length=100,
-         global_batch_size=batch_tokens_size,
-         is_training=True,
-         static_batch=False,
-         src_lang='en',
-         tgt_lang='reverse_en',
-         sentencepiece_model_path=self._sentencepeice_model_path)
-     with self.assertRaisesRegex(
-         ValueError, 'The token budget, global batch size, is too small.*'):
-       _ = wmt_dataloader.WMTDataLoader(data_config).load()
-
-
- if __name__ == '__main__':
-   tf.test.main()
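
Note: the deleted _generate_record_file helper serialized each sentence pair as a tf.train.Example with string features keyed 'en' and 'reverse_en'. For orientation, here is a minimal sketch of reading such a record back with standard TensorFlow APIs; the file path is a hypothetical placeholder, not anything from this commit.

import tensorflow as tf

# Feature spec mirroring what _generate_record_file wrote: one scalar
# string feature per language key.
_FEATURE_SPEC = {
    'en': tf.io.FixedLenFeature([], tf.string),
    'reverse_en': tf.io.FixedLenFeature([], tf.string),
}

def _parse(serialized):
  # Decode one serialized tf.train.Example into a dict of string tensors.
  return tf.io.parse_single_example(serialized, _FEATURE_SPEC)

# '/tmp/train.record' is a placeholder for a file written as above.
ds = tf.data.TFRecordDataset('/tmp/train.record').map(_parse)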
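
The core behavior the deleted test covered was that WMTDataLoader(config).load() yields batches with 'inputs' and 'targets' tensors. A minimal sketch of the same call outside the test harness, with config values copied from the test's 'train_static' case; both paths are hypothetical placeholders for artifacts like those built in setUp().

from official.nlp.data import wmt_dataloader

# Config values mirror the 'train_static' parameterization above.
config = wmt_dataloader.WMTDataConfig(
    input_path='/tmp/train.record',            # hypothetical TFRecord path
    max_seq_length=35,
    global_batch_size=100,
    is_training=True,
    static_batch=True,
    src_lang='en',
    tgt_lang='reverse_en',
    sentencepiece_model_path='/tmp/sp.model')  # hypothetical model path

dataset = wmt_dataloader.WMTDataLoader(config).load()
batch = next(iter(dataset))  # dict with 'inputs' and 'targets' tensors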