Pradeep Kumar committed on
Commit 01a6c3b · verified · 1 Parent(s): 439fdba

Delete create_pretraining_data_test.py

Files changed (1)
  1. create_pretraining_data_test.py +0 -128
create_pretraining_data_test.py DELETED
@@ -1,128 +0,0 @@
- # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- """Tests for official.nlp.data.create_pretraining_data."""
- import random
-
- import tensorflow as tf, tf_keras
-
- from official.nlp.data import create_pretraining_data as cpd
-
- _VOCAB_WORDS = ["vocab_1", "vocab_2"]
-
-
- class CreatePretrainingDataTest(tf.test.TestCase):
-
-   def assertTokens(self, input_tokens, output_tokens, masked_positions,
-                    masked_labels):
-     # Ensure the masked positions are unique.
-     self.assertCountEqual(masked_positions, set(masked_positions))
-
-     # Ensure we can reconstruct the input from the output.
-     reconstructed_tokens = output_tokens
-     for pos, label in zip(masked_positions, masked_labels):
-       reconstructed_tokens[pos] = label
-     self.assertEqual(input_tokens, reconstructed_tokens)
-
-     # Ensure each label is valid.
-     for pos, label in zip(masked_positions, masked_labels):
-       output_token = output_tokens[pos]
-       if (output_token == "[MASK]" or output_token in _VOCAB_WORDS or
-           output_token == input_tokens[pos]):
-         continue
-       self.fail("invalid mask value: {}".format(output_token))
-
-   def test_tokens_to_grams(self):
-     tests = [
-         (["That", "cone"], [(0, 1), (1, 2)]),
-         (["That", "cone", "##s"], [(0, 1), (1, 3)]),
-         (["Swit", "##zer", "##land"], [(0, 3)]),
-         (["[CLS]", "Up", "##dog"], [(1, 3)]),
-         (["[CLS]", "Up", "##dog", "[SEP]", "Down"], [(1, 3), (4, 5)]),
-     ]
-     for inp, expected in tests:
-       output = cpd._tokens_to_grams(inp)
-       self.assertEqual(expected, output)
-
-   def test_window(self):
-     input_list = [1, 2, 3, 4]
-     window_outputs = [
-         (1, [[1], [2], [3], [4]]),
-         (2, [[1, 2], [2, 3], [3, 4]]),
-         (3, [[1, 2, 3], [2, 3, 4]]),
-         (4, [[1, 2, 3, 4]]),
-         (5, []),
-     ]
-     for window, expected in window_outputs:
-       output = cpd._window(input_list, window)
-       self.assertEqual(expected, list(output))
-
-   def test_create_masked_lm_predictions(self):
-     tokens = ["[CLS]", "a", "##a", "b", "##b", "c", "##c", "[SEP]"]
-     rng = random.Random(123)
-     for _ in range(0, 5):
-       output_tokens, masked_positions, masked_labels = (
-           cpd.create_masked_lm_predictions(
-               tokens=tokens,
-               masked_lm_prob=1.0,
-               max_predictions_per_seq=3,
-               vocab_words=_VOCAB_WORDS,
-               rng=rng,
-               do_whole_word_mask=False,
-               max_ngram_size=None))
-       self.assertLen(masked_positions, 3)
-       self.assertLen(masked_labels, 3)
-       self.assertTokens(tokens, output_tokens, masked_positions, masked_labels)
-
-   def test_create_masked_lm_predictions_whole_word(self):
-     tokens = ["[CLS]", "a", "##a", "b", "##b", "c", "##c", "[SEP]"]
-     rng = random.Random(345)
-     for _ in range(0, 5):
-       output_tokens, masked_positions, masked_labels = (
-           cpd.create_masked_lm_predictions(
-               tokens=tokens,
-               masked_lm_prob=1.0,
-               max_predictions_per_seq=3,
-               vocab_words=_VOCAB_WORDS,
-               rng=rng,
-               do_whole_word_mask=True,
-               max_ngram_size=None))
-       # Since we can't get exactly three tokens without breaking a word, we
-       # only take two.
-       self.assertLen(masked_positions, 2)
-       self.assertLen(masked_labels, 2)
-       self.assertTokens(tokens, output_tokens, masked_positions, masked_labels)
-       # Ensure that we took an entire word.
-       self.assertIn(masked_labels, [["a", "##a"], ["b", "##b"], ["c", "##c"]])
-
-   def test_create_masked_lm_predictions_ngram(self):
-     tokens = ["[CLS]"] + ["tok{}".format(i) for i in range(0, 512)] + ["[SEP]"]
-     rng = random.Random(345)
-     for _ in range(0, 5):
-       output_tokens, masked_positions, masked_labels = (
-           cpd.create_masked_lm_predictions(
-               tokens=tokens,
-               masked_lm_prob=1.0,
-               max_predictions_per_seq=76,
-               vocab_words=_VOCAB_WORDS,
-               rng=rng,
-               do_whole_word_mask=True,
-               max_ngram_size=3))
-       self.assertLen(masked_positions, 76)
-       self.assertLen(masked_labels, 76)
-       self.assertTokens(tokens, output_tokens, masked_positions, masked_labels)
-
-
- if __name__ == "__main__":
-   tf.test.main()
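
For reference, the deleted tests exercised cpd.create_masked_lm_predictions roughly as in the sketch below. This is a minimal illustration only, assuming the official.nlp.data.create_pretraining_data module from tensorflow/models remains importable; the token list and seed here are hypothetical, not taken from the deleted file, while the keyword arguments match those used in the tests above.

import random

from official.nlp.data import create_pretraining_data as cpd

# Illustrative wordpiece sequence; "##" marks subword continuations, so
# whole-word masking must mask "fish" and "##ing" together.
tokens = ["[CLS]", "fish", "##ing", "rod", "[SEP]"]
rng = random.Random(0)  # fixed seed so the masking is reproducible

output_tokens, masked_positions, masked_labels = (
    cpd.create_masked_lm_predictions(
        tokens=tokens,
        masked_lm_prob=1.0,            # try to mask every eligible token
        max_predictions_per_seq=2,     # hard cap on masked positions
        vocab_words=["vocab_1", "vocab_2"],  # pool for random replacements
        rng=rng,
        do_whole_word_mask=True,
        max_ngram_size=None))

# Per the deleted assertTokens helper, each masked position in
# output_tokens now holds "[MASK]", a random vocab word, or the original
# token, and masked_labels records the original tokens at those positions.
print(output_tokens, masked_positions, masked_labels)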