Spaces:
Running
Running
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os | |
import tempfile | |
import six | |
import tensorflow as tf, tf_keras | |
from official.nlp.tools import tokenization | |
class TokenizationTest(tf.test.TestCase):
  """Tokenization test.

  The implementation is forked from
  https://github.com/google-research/bert/blob/master/tokenization_test.py.
  """

  def test_full_tokenizer(self):
    """FullTokenizer built from a vocab file tokenizes and maps to ids.

    The input mixes upper case, an accented character and punctuation; the
    expected tokens are lower-cased, accent-free wordpieces, and the expected
    ids are the positions of those tokens in the vocab file.
    """
    vocab_tokens = [
        "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
        "##ing", ","
    ]
    vocab_text = "".join([x + "\n" for x in vocab_tokens])

    # delete=False keeps the file on disk after it is closed so that it can be
    # handed to FullTokenizer by name; the finally clause removes it even when
    # writing the vocab or constructing the tokenizer raises.
    vocab_writer = tempfile.NamedTemporaryFile(delete=False)
    vocab_file = vocab_writer.name
    try:
      with vocab_writer:
        if six.PY2:
          vocab_writer.write(vocab_text)
        else:
          vocab_writer.write(vocab_text.encode("utf-8"))
      tokenizer = tokenization.FullTokenizer(vocab_file)
    finally:
      os.unlink(vocab_file)

    tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
    self.assertAllEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])

    self.assertAllEqual(
        tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])

  def test_chinese(self):
    """BasicTokenizer emits each of the two CJK characters as its own token."""
    tokenizer = tokenization.BasicTokenizer()

    self.assertAllEqual(
        tokenizer.tokenize(u"ah\u535A\u63A8zz"),
        [u"ah", u"\u535A", u"\u63A8", u"zz"])

  def test_basic_tokenizer_lower(self):
    """With do_lower_case=True, output is lower-cased and accent-free."""
    tokenizer = tokenization.BasicTokenizer(do_lower_case=True)

    self.assertAllEqual(
        tokenizer.tokenize(u" \tHeLLo!how  \n Are yoU?  "),
        ["hello", "!", "how", "are", "you", "?"])
    self.assertAllEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"])

  def test_basic_tokenizer_no_lower(self):
    """With do_lower_case=False, the original casing is preserved."""
    tokenizer = tokenization.BasicTokenizer(do_lower_case=False)

    self.assertAllEqual(
        tokenizer.tokenize(u" \tHeLLo!how  \n Are yoU?  "),
        ["HeLLo", "!", "how", "Are", "yoU", "?"])

  def test_basic_tokenizer_no_split_on_punc(self):
    """With split_on_punc=False, punctuation stays attached to its word."""
    tokenizer = tokenization.BasicTokenizer(
        do_lower_case=True, split_on_punc=False)

    self.assertAllEqual(
        tokenizer.tokenize(u" \tHeLLo!how  \n Are yoU?  "),
        ["hello!how", "are", "you?"])

  def test_wordpiece_tokenizer(self):
    """WordpieceTokenizer splits words into vocab pieces or emits [UNK]."""
    vocab_tokens = [
        "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
        "##ing", "##!", "!"
    ]
    # Map every token to its position in the list.
    vocab = {token: i for i, token in enumerate(vocab_tokens)}
    tokenizer = tokenization.WordpieceTokenizer(vocab=vocab)

    self.assertAllEqual(tokenizer.tokenize(""), [])

    self.assertAllEqual(
        tokenizer.tokenize("unwanted running"),
        ["un", "##want", "##ed", "runn", "##ing"])

    # A whitespace-separated "!" is a standalone token ...
    self.assertAllEqual(
        tokenizer.tokenize("unwanted running !"),
        ["un", "##want", "##ed", "runn", "##ing", "!"])

    # ... while a "!" glued to the word becomes the continuation piece "##!".
    self.assertAllEqual(
        tokenizer.tokenize("unwanted running!"),
        ["un", "##want", "##ed", "runn", "##ing", "##!"])

    # A word that cannot be fully covered by vocab pieces maps to [UNK].
    self.assertAllEqual(
        tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])

  def test_convert_tokens_to_ids(self):
    """convert_tokens_to_ids looks each token up in the vocab mapping."""
    vocab_tokens = [
        "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
        "##ing"
    ]
    # Map every token to its position in the list.
    vocab = {token: i for i, token in enumerate(vocab_tokens)}

    token_ids = tokenization.convert_tokens_to_ids(
        vocab, ["un", "##want", "##ed", "runn", "##ing"])
    self.assertAllEqual(token_ids, [7, 4, 5, 8, 9])

  def test_is_whitespace(self):
    """_is_whitespace accepts space, tab, CR, LF and U+00A0 only."""
    self.assertTrue(tokenization._is_whitespace(u" "))
    self.assertTrue(tokenization._is_whitespace(u"\t"))
    self.assertTrue(tokenization._is_whitespace(u"\r"))
    self.assertTrue(tokenization._is_whitespace(u"\n"))
    self.assertTrue(tokenization._is_whitespace(u"\u00A0"))

    self.assertFalse(tokenization._is_whitespace(u"A"))
    self.assertFalse(tokenization._is_whitespace(u"-"))

  def test_is_control(self):
    """_is_control accepts U+0005 but not letters, whitespace or U+1F4A9."""
    self.assertTrue(tokenization._is_control(u"\u0005"))

    self.assertFalse(tokenization._is_control(u"A"))
    self.assertFalse(tokenization._is_control(u" "))
    self.assertFalse(tokenization._is_control(u"\t"))
    self.assertFalse(tokenization._is_control(u"\r"))
    self.assertFalse(tokenization._is_control(u"\U0001F4A9"))

  def test_is_punctuation(self):
    """_is_punctuation accepts "-", "$", "`" and "." but not "A" or " "."""
    self.assertTrue(tokenization._is_punctuation(u"-"))
    self.assertTrue(tokenization._is_punctuation(u"$"))
    self.assertTrue(tokenization._is_punctuation(u"`"))
    self.assertTrue(tokenization._is_punctuation(u"."))

    self.assertFalse(tokenization._is_punctuation(u"A"))
    self.assertFalse(tokenization._is_punctuation(u" "))
# Run the whole test suite when this file is executed directly as a script.
if __name__ == "__main__":
  tf.test.main()