# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test Transformer model."""
from absl import logging
from absl.testing import parameterized
import numpy as np
import tensorflow as tf, tf_keras
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.nlp.modeling.layers import attention
from official.nlp.modeling.models import seq2seq_transformer


class Seq2SeqTransformerTest(tf.test.TestCase, parameterized.TestCase):

  def _build_model(
      self,
      padded_decode,
      decode_max_length,
      embedding_width,
      self_attention_cls=None,
      cross_attention_cls=None,
  ):
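    """Builds a small Seq2SeqTransformer with one encoder/decoder layer."""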
    num_layers = 1
    num_attention_heads = 2
    intermediate_size = 32
    vocab_size = 100
    encdec_kwargs = dict(
        num_layers=num_layers,
        num_attention_heads=num_attention_heads,
        intermediate_size=intermediate_size,
        activation="relu",
        dropout_rate=0.01,
        attention_dropout_rate=0.01,
        use_bias=False,
        norm_first=True,
        norm_epsilon=1e-6,
        intermediate_dropout=0.01)
    encoder_layer = seq2seq_transformer.TransformerEncoder(**encdec_kwargs)
    decoder_layer = seq2seq_transformer.TransformerDecoder(
        **encdec_kwargs,
        self_attention_cls=self_attention_cls,
        cross_attention_cls=cross_attention_cls
    )
    return seq2seq_transformer.Seq2SeqTransformer(
        vocab_size=vocab_size,
        embedding_width=embedding_width,
        dropout_rate=0.01,
        padded_decode=padded_decode,
        decode_max_length=decode_max_length,
        beam_size=4,
        alpha=0.6,
        encoder_layer=encoder_layer,
        decoder_layer=decoder_layer)

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.default_strategy,
              strategy_combinations.cloud_tpu_strategy,
          ],
          embed=[True, False],
          is_training=[True, False],
          custom_self_attention=[False, True],
          custom_cross_attention=[False, True],
          mode="eager"))
  def test_create_model_with_ds(
      self,
      distribution,
      embed,
      is_training,
      custom_self_attention,
      custom_cross_attention,
  ):
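    """Builds the model under a distribution strategy and runs it once."""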
    self_attention_called = False
    cross_attention_called = False

    class SelfAttention(attention.CachedAttention):
      """Dummy implementation of custom attention."""

      def __call__(self, *args, **kwargs):
        nonlocal self_attention_called
        self_attention_called = True
        return super().__call__(*args, **kwargs)

    class CrossAttention:
      """Dummy implementation of custom attention."""

      def __init__(self, *args, **kwargs):
        pass

      def __call__(self, query, value, attention_mask, **kwargs):
        nonlocal cross_attention_called
        cross_attention_called = True
        return query

    with distribution.scope():
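      # Autoregressive decoding on TPU needs static shapes, so padded
      # decoding is enabled whenever a TPU strategy is used.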
      padded_decode = isinstance(
          distribution,
          (tf.distribute.TPUStrategy, tf.distribute.experimental.TPUStrategy))
      decode_max_length = 10
      batch_size = 4
      embedding_width = 16
      model = self._build_model(
          padded_decode,
          decode_max_length,
          embedding_width,
          self_attention_cls=SelfAttention if custom_self_attention else None,
          cross_attention_cls=CrossAttention
          if custom_cross_attention
          else None,
      )
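
      # Runs one replicated forward pass and gathers per-replica outputs.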
      @tf.function
      def step(inputs):

        def _step_fn(inputs):
          return model(inputs)

        outputs = distribution.run(_step_fn, args=(inputs,))
        return tf.nest.map_structure(distribution.experimental_local_results,
                                     outputs)

      if embed:
        fake_inputs = dict(
            embedded_inputs=np.zeros(
                (batch_size, decode_max_length, embedding_width),
                dtype=np.float32),
            input_masks=np.ones((batch_size, decode_max_length), dtype=bool))
      else:
        fake_inputs = dict(
            inputs=np.zeros((batch_size, decode_max_length), dtype=np.int32))

      if is_training:
        fake_inputs["targets"] = np.zeros((batch_size, 8), dtype=np.int32)
        local_outputs = step(fake_inputs)
        logging.info("local_outputs=%s", local_outputs)
        # Training returns logits: (batch, target_length, vocab_size).
        self.assertEqual(local_outputs[0].shape, (4, 8, 100))
      else:
        local_outputs = step(fake_inputs)
        logging.info("local_outputs=%s", local_outputs)
        # Inference returns decoded ids: (batch, decode_max_length).
        self.assertEqual(local_outputs["outputs"][0].shape, (4, 10))

      self.assertEqual(self_attention_called, custom_self_attention)
      self.assertEqual(cross_attention_called, custom_cross_attention)

  @parameterized.parameters(True, False)
  def test_create_savedmodel(self, padded_decode):
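    """Exports the model as a SavedModel with two serving signatures."""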
    decode_max_length = 10
    embedding_width = 16
    model = self._build_model(
        padded_decode, decode_max_length, embedding_width)

    class SaveModule(tf.Module):

      def __init__(self, model):
        super(SaveModule, self).__init__()
        self.model = model

      @tf.function
      def serve(self, inputs):
        return self.model.call(dict(inputs=inputs))

      @tf.function
      def embedded_serve(self, embedded_inputs, input_masks):
        return self.model.call(
            dict(embedded_inputs=embedded_inputs, input_masks=input_masks))

    save_module = SaveModule(model)
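    # Padded decoding needs fully static input shapes in the signatures;
    # otherwise the batch and sequence dimensions stay dynamic.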
    if padded_decode:
      tensor_shape = (4, decode_max_length)
      embedded_tensor_shape = (4, decode_max_length, embedding_width)
    else:
      tensor_shape = (None, None)
      embedded_tensor_shape = (None, None, embedding_width)
    signatures = dict(
        serving_default=save_module.serve.get_concrete_function(
            tf.TensorSpec(shape=tensor_shape, dtype=tf.int32, name="inputs")),
        embedded_serving=save_module.embedded_serve.get_concrete_function(
            tf.TensorSpec(
                shape=embedded_tensor_shape, dtype=tf.float32,
                name="embedded_inputs"),
            tf.TensorSpec(
                shape=tensor_shape, dtype=tf.bool, name="input_masks"),
        ))
    tf.saved_model.save(save_module, self.get_temp_dir(), signatures=signatures)


if __name__ == "__main__":
  tf.test.main()