# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test Transformer model.""" | |
from absl import logging | |
from absl.testing import parameterized | |
import numpy as np | |
import tensorflow as tf, tf_keras | |
from tensorflow.python.distribute import combinations | |
from tensorflow.python.distribute import strategy_combinations | |
from official.nlp.modeling.layers import attention | |
from official.nlp.modeling.models import seq2seq_transformer | |


class Seq2SeqTransformerTest(tf.test.TestCase, parameterized.TestCase):

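  # Builds a tiny one-layer Seq2SeqTransformer (beam size 4) so that each
  # test case stays fast.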
  def _build_model(
      self,
      padded_decode,
      decode_max_length,
      embedding_width,
      self_attention_cls=None,
      cross_attention_cls=None,
  ):
    num_layers = 1
    num_attention_heads = 2
    intermediate_size = 32
    vocab_size = 100
    encdec_kwargs = dict(
        num_layers=num_layers,
        num_attention_heads=num_attention_heads,
        intermediate_size=intermediate_size,
        activation="relu",
        dropout_rate=0.01,
        attention_dropout_rate=0.01,
        use_bias=False,
        norm_first=True,
        norm_epsilon=1e-6,
        intermediate_dropout=0.01)
    encoder_layer = seq2seq_transformer.TransformerEncoder(**encdec_kwargs)
    decoder_layer = seq2seq_transformer.TransformerDecoder(
        **encdec_kwargs,
        self_attention_cls=self_attention_cls,
        cross_attention_cls=cross_attention_cls,
    )
    return seq2seq_transformer.Seq2SeqTransformer(
        vocab_size=vocab_size,
        embedding_width=embedding_width,
        dropout_rate=0.01,
        padded_decode=padded_decode,
        decode_max_length=decode_max_length,
        beam_size=4,
        alpha=0.6,
        encoder_layer=encoder_layer,
        decoder_layer=decoder_layer)
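
  # Verifies that the model builds and runs a forward pass under a
  # tf.distribute strategy, in both the training and decoding paths, and that
  # custom attention classes are actually invoked. The parameter sweep below
  # is an assumption, reconstructed from the arguments the test consumes.
  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.default_strategy,
              strategy_combinations.cloud_tpu_strategy,
              strategy_combinations.one_device_strategy_gpu,
          ],
          mode="eager",
          embed=[True, False],
          is_training=[True, False],
          custom_self_attention=[True, False],
          custom_cross_attention=[True, False]))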
  def test_create_model_with_ds(
      self,
      distribution,
      embed,
      is_training,
      custom_self_attention,
      custom_cross_attention,
  ):
    self_attention_called = False
    cross_attention_called = False

    class SelfAttention(attention.CachedAttention):
      """Dummy implementation of custom attention."""

      def __call__(self, *args, **kwargs):
        nonlocal self_attention_called
        self_attention_called = True
        return super().__call__(*args, **kwargs)

    class CrossAttention:
      """Dummy implementation of custom attention."""

      def __init__(self, *args, **kwargs):
        pass

      def __call__(self, query, value, attention_mask, **kwargs):
        nonlocal cross_attention_called
        cross_attention_called = True
        return query

    with distribution.scope():
      padded_decode = isinstance(
          distribution,
          (tf.distribute.TPUStrategy, tf.distribute.experimental.TPUStrategy))
      decode_max_length = 10
      batch_size = 4
      embedding_width = 16
      model = self._build_model(
          padded_decode,
          decode_max_length,
          embedding_width,
          self_attention_cls=SelfAttention if custom_self_attention else None,
          cross_attention_cls=CrossAttention
          if custom_cross_attention
          else None,
      )

      def step(inputs):
        def _step_fn(inputs):
          return model(inputs)

        outputs = distribution.run(_step_fn, args=(inputs,))
        return tf.nest.map_structure(distribution.experimental_local_results,
                                     outputs)
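
      # The model accepts either pre-embedded float inputs plus a mask or raw
      # int32 token ids; build whichever fake input format this case uses.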
      if embed:
        fake_inputs = dict(
            embedded_inputs=np.zeros(
                (batch_size, decode_max_length, embedding_width),
                dtype=np.float32),
            input_masks=np.ones((batch_size, decode_max_length), dtype=bool))
      else:
        fake_inputs = dict(
            inputs=np.zeros((batch_size, decode_max_length), dtype=np.int32))

      if is_training:
        fake_inputs["targets"] = np.zeros((batch_size, 8), dtype=np.int32)
        local_outputs = step(fake_inputs)
        logging.info("local_outputs=%s", local_outputs)
        self.assertEqual(local_outputs[0].shape, (4, 8, 100))
      else:
        local_outputs = step(fake_inputs)
        logging.info("local_outputs=%s", local_outputs)
        self.assertEqual(local_outputs["outputs"][0].shape, (4, 10))

    self.assertEqual(self_attention_called, custom_self_attention)
    self.assertEqual(cross_attention_called, custom_cross_attention)
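
  # Exports the model as a SavedModel with token-id and pre-embedded serving
  # signatures. The parameterization is an assumption (the original decorator
  # was not preserved) and simply sweeps padded_decode.
  @parameterized.parameters(True, False)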
  def test_create_savedmodel(self, padded_decode):
    decode_max_length = 10
    embedding_width = 16
    model = self._build_model(
        padded_decode, decode_max_length, embedding_width)

    class SaveModule(tf.Module):

      def __init__(self, model):
        super(SaveModule, self).__init__()
        self.model = model

      # tf.function is required so get_concrete_function() is available when
      # building the serving signatures below.
      @tf.function
      def serve(self, inputs):
        return self.model.call(dict(inputs=inputs))

      @tf.function
      def embedded_serve(self, embedded_inputs, input_masks):
        return self.model.call(
            dict(embedded_inputs=embedded_inputs, input_masks=input_masks))

    save_module = SaveModule(model)
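
    # Padded decoding requires fully static shapes, so pin the batch and
    # sequence dimensions; otherwise leave them dynamic.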
    if padded_decode:
      tensor_shape = (4, decode_max_length)
      embedded_tensor_shape = (4, decode_max_length, embedding_width)
    else:
      tensor_shape = (None, None)
      embedded_tensor_shape = (None, None, embedding_width)
    signatures = dict(
        serving_default=save_module.serve.get_concrete_function(
            tf.TensorSpec(shape=tensor_shape, dtype=tf.int32, name="inputs")),
        embedded_serving=save_module.embedded_serve.get_concrete_function(
            tf.TensorSpec(
                shape=embedded_tensor_shape, dtype=tf.float32,
                name="embedded_inputs"),
            tf.TensorSpec(
                shape=tensor_shape, dtype=tf.bool, name="input_masks"),
        ))
    tf.saved_model.save(
        save_module, self.get_temp_dir(), signatures=signatures)


if __name__ == "__main__":
  tf.test.main()