deanna-emery's picture
updates
93528c6
raw
history blame
6.88 kB
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.nlp.tasks.translation."""
import functools
import os
import orbit
import tensorflow as tf, tf_keras
from sentencepiece import SentencePieceTrainer
from official.nlp.data import wmt_dataloader
from official.nlp.tasks import translation
def _generate_line_file(filepath, lines):
with tf.io.gfile.GFile(filepath, "w") as f:
for l in lines:
f.write("{}\n".format(l))
def _generate_record_file(filepath, src_lines, tgt_lines):
writer = tf.io.TFRecordWriter(filepath)
for src, tgt in zip(src_lines, tgt_lines):
example = tf.train.Example(
features=tf.train.Features(
feature={
"en": tf.train.Feature(
bytes_list=tf.train.BytesList(
value=[src.encode()])),
"reverse_en": tf.train.Feature(
bytes_list=tf.train.BytesList(
value=[tgt.encode()])),
}))
writer.write(example.SerializeToString())
writer.close()
def _train_sentencepiece(input_path, vocab_size, model_path, eos_id=1):
argstr = " ".join([
f"--input={input_path}", f"--vocab_size={vocab_size}",
"--character_coverage=0.995",
f"--model_prefix={model_path}", "--model_type=bpe",
"--bos_id=-1", "--pad_id=0", f"--eos_id={eos_id}", "--unk_id=2"
])
SentencePieceTrainer.Train(argstr)
class TranslationTaskTest(tf.test.TestCase):
def setUp(self):
super(TranslationTaskTest, self).setUp()
self._temp_dir = self.get_temp_dir()
src_lines = [
"abc ede fg",
"bbcd ef a g",
"de f a a g"
]
tgt_lines = [
"dd cc a ef g",
"bcd ef a g",
"gef cd ba"
]
self._record_input_path = os.path.join(self._temp_dir, "inputs.record")
_generate_record_file(self._record_input_path, src_lines, tgt_lines)
self._sentencepeice_input_path = os.path.join(self._temp_dir, "inputs.txt")
_generate_line_file(self._sentencepeice_input_path, src_lines + tgt_lines)
sentencepeice_model_prefix = os.path.join(self._temp_dir, "sp")
_train_sentencepiece(self._sentencepeice_input_path, 11,
sentencepeice_model_prefix)
self._sentencepeice_model_path = "{}.model".format(
sentencepeice_model_prefix)
def test_task(self):
config = translation.TranslationConfig(
model=translation.ModelConfig(
encoder=translation.EncDecoder(num_layers=1),
decoder=translation.EncDecoder(num_layers=1)),
train_data=wmt_dataloader.WMTDataConfig(
input_path=self._record_input_path,
src_lang="en", tgt_lang="reverse_en",
is_training=True, static_batch=True, global_batch_size=24,
max_seq_length=12),
sentencepiece_model_path=self._sentencepeice_model_path)
task = translation.TranslationTask(config)
model = task.build_model()
dataset = task.build_inputs(config.train_data)
iterator = iter(dataset)
optimizer = tf_keras.optimizers.SGD(lr=0.1)
task.train_step(next(iterator), model, optimizer)
def test_no_sentencepiece_path(self):
config = translation.TranslationConfig(
model=translation.ModelConfig(
encoder=translation.EncDecoder(num_layers=1),
decoder=translation.EncDecoder(num_layers=1)),
train_data=wmt_dataloader.WMTDataConfig(
input_path=self._record_input_path,
src_lang="en", tgt_lang="reverse_en",
is_training=True, static_batch=True, global_batch_size=4,
max_seq_length=4),
sentencepiece_model_path=None)
with self.assertRaisesRegex(
ValueError,
"Setencepiece model path not provided."):
translation.TranslationTask(config)
def test_sentencepiece_no_eos(self):
sentencepeice_model_prefix = os.path.join(self._temp_dir, "sp_no_eos")
_train_sentencepiece(self._sentencepeice_input_path, 20,
sentencepeice_model_prefix, eos_id=-1)
sentencepeice_model_path = "{}.model".format(
sentencepeice_model_prefix)
config = translation.TranslationConfig(
model=translation.ModelConfig(
encoder=translation.EncDecoder(num_layers=1),
decoder=translation.EncDecoder(num_layers=1)),
train_data=wmt_dataloader.WMTDataConfig(
input_path=self._record_input_path,
src_lang="en", tgt_lang="reverse_en",
is_training=True, static_batch=True, global_batch_size=4,
max_seq_length=4),
sentencepiece_model_path=sentencepeice_model_path)
with self.assertRaisesRegex(
ValueError,
"EOS token not in tokenizer vocab.*"):
translation.TranslationTask(config)
def test_evaluation(self):
config = translation.TranslationConfig(
model=translation.ModelConfig(
encoder=translation.EncDecoder(num_layers=1),
decoder=translation.EncDecoder(num_layers=1),
padded_decode=False,
decode_max_length=64),
validation_data=wmt_dataloader.WMTDataConfig(
input_path=self._record_input_path, src_lang="en",
tgt_lang="reverse_en", static_batch=True, global_batch_size=4),
sentencepiece_model_path=self._sentencepeice_model_path)
logging_dir = self.get_temp_dir()
task = translation.TranslationTask(config, logging_dir=logging_dir)
dataset = orbit.utils.make_distributed_dataset(tf.distribute.get_strategy(),
task.build_inputs,
config.validation_data)
model = task.build_model()
strategy = tf.distribute.get_strategy()
aggregated = None
for data in dataset:
distributed_outputs = strategy.run(
functools.partial(task.validation_step, model=model),
args=(data,))
outputs = tf.nest.map_structure(strategy.experimental_local_results,
distributed_outputs)
aggregated = task.aggregate_logs(state=aggregated, step_outputs=outputs)
metrics = task.reduce_aggregated_logs(aggregated)
self.assertIn("sacrebleu_score", metrics)
self.assertIn("bleu_score", metrics)
if __name__ == "__main__":
tf.test.main()