# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for t5."""

from absl.testing import parameterized
import numpy as np
import tensorflow as tf
import tf_keras

from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.nlp.modeling.models import t5


def _create_cache(batch_size,
                  init_decode_length,
                  num_heads,
                  head_size,
                  dtype=tf.float32):
  """Builds a zero-filled key/value decoding cache.

  Args:
    batch_size: Size of the leading dimension of both tensors.
    init_decode_length: Size of the second dimension of both tensors.
    num_heads: Size of the heads dimension, or None to omit that dimension.
    head_size: Size of the trailing dimension of both tensors.
    dtype: dtype of the created tensors.

  Returns:
    A dict with "key" and "value" zero tensors of identical shape.
  """
  if num_heads is None:
    kv_shape = [batch_size, init_decode_length, head_size]
  else:
    kv_shape = [batch_size, init_decode_length, num_heads, head_size]
  return {
      "key": tf.zeros(kv_shape, dtype=dtype),
      "value": tf.zeros(kv_shape, dtype=dtype)
  }


class ModulesTest(tf.test.TestCase, parameterized.TestCase):
  """Shape, dtype and variable-count checks for the individual t5 modules."""

  @parameterized.named_parameters(("bfloat16", tf.bfloat16),
                                  ("float32", tf.float32))
  def test_embed(self, dtype):
    """One-hot and gather lookups agree; `attend` yields vocab-sized output."""
    embed = t5.Embed(vocab_size=5, features=4, compute_dtype=dtype, name="foo")
    inputs = tf.convert_to_tensor(np.array([[2, 3], [1, 2]], dtype=np.int32))
    one_hot_outputs = embed(inputs, one_hot=True)
    gather_outputs = embed(inputs, one_hot=False)
    self.assertEqual(one_hot_outputs.shape, (2, 2, 4))
    self.assertLen(embed.trainable_variables, 1)
    self.assertAllClose(one_hot_outputs, gather_outputs)
    outputs = embed.attend(query=tf.zeros((2, 2, 4), dtype))
    self.assertEqual(outputs.shape, (2, 2, 5))

    # Test initializers.
    embed = t5.Embed(
        vocab_size=5,
        features=4,
        compute_dtype=dtype,
        name="foo",
        embeddings_initializer=tf_keras.initializers.Zeros())
    self.assertAllClose(embed(inputs), tf.zeros((2, 2, 4), dtype))

  @parameterized.named_parameters(("bfloat16", tf.bfloat16),
                                  ("float32", tf.float32))
  def test_rms_norm(self, dtype):
    """RMSNorm of an all-ones input returns the input in the same dtype."""
    norm = t5.RMSNorm(hidden_size=4, epsilon=0.0, name="foo")
    inputs = tf.ones((2, 4), dtype=dtype)
    outputs = norm(inputs)
    # Assert on the tensor already computed instead of invoking the layer twice.
    self.assertAllEqual(outputs, inputs)
    self.assertEqual(outputs.dtype, dtype)
    self.assertLen(norm.trainable_variables, 1)
    self.assertIn("foo/scale", norm.trainable_variables[0].name)

  @parameterized.named_parameters(("bfloat16", tf.bfloat16),
                                  ("float32", tf.float32))
  def test_linear(self, dtype):
    """Linear keeps the input shape and dtype and owns two variables."""
    linear = t5.Linear(
        in_features=4,
        out_features=4,
        w_init=tf_keras.initializers.Ones(),
        name="foo")
    inputs = tf.ones((2, 4), dtype=dtype)
    outputs = linear(inputs)
    self.assertEqual(outputs.shape, inputs.shape)
    self.assertEqual(outputs.dtype, dtype)
    self.assertLen(linear.trainable_variables, 2)

  def test_linear3d(self):
    """Linear3D output shapes for both `to_3d` settings."""
    batch_size = 2
    to_heads = t5.Linear3D(
        in_features=4,
        out_features=4,
        num_heads=2,
        to_3d=True,
        w_init=tf_keras.initializers.Ones(),
        name="foo")
    inputs = np.ones((batch_size, 2, 4), dtype=np.float32)
    self.assertEqual(to_heads(inputs).shape, (batch_size, 2, 2, 4))

    from_heads = t5.Linear3D(
        in_features=2,
        out_features=4,
        num_heads=2,
        to_3d=False,
        w_init=tf_keras.initializers.Ones(),
        name="foo")
    inputs = np.ones((batch_size, 2, 2, 2), dtype=np.float32)
    self.assertEqual(from_heads(inputs).shape, (batch_size, 2, 4))

  def test_ffn(self):
    """FFN keeps the input shape for single and two-activation setups."""
    inputs = np.ones((2, 4), dtype=np.float32)
    for activation in ["relu", "linear", "gelu", "swish"]:
      ffn = t5.FFN(
          d_model=4,
          d_ff=8,
          use_bias=True,
          dropout_rate=0.1,
          activations=[activation],
          name="foo")
      self.assertEqual(ffn(inputs).shape, inputs.shape)
      self.assertLen(ffn.trainable_variables, 4)

    ffn = t5.FFN(
        d_model=4,
        d_ff=8,
        dropout_rate=0.1,
        activations=["linear", "gelu"],
        name="bar")
    self.assertLen(ffn.trainable_variables, 3)
    self.assertEqual(ffn(inputs).shape, inputs.shape)

  @parameterized.named_parameters(("bfloat16", tf.bfloat16),
                                  ("float32", tf.float32))
  def test_relative_position(self, dtype):
    """Relative position bias shape and dtype for both directions."""
    unidirectional = t5.RelativePositionEmbedding(
        num_heads=4,
        bidirectional=False,
        embeddings_initializer=tf_keras.initializers.Ones(),
        compute_dtype=dtype,
        name="foo")
    self.assertEqual(unidirectional(4, 2).shape, (1, 4, 4, 2))

    bidirectional = t5.RelativePositionEmbedding(
        num_heads=4,
        bidirectional=True,
        embeddings_initializer=tf_keras.initializers.Ones(),
        compute_dtype=dtype,
        name="bar")
    outputs = bidirectional(4, 2)
    self.assertEqual(outputs.shape, (1, 4, 4, 2))
    self.assertEqual(outputs.dtype, dtype)

  def test_masks(self):
    """`make_causal_mask` on a (2, 5) input yields a (2, 1, 5, 5) mask."""
    causal_mask = t5.make_causal_mask(np.zeros((2, 5)))
    self.assertEqual(causal_mask.shape, (2, 1, 5, 5))

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.default_strategy,
              strategy_combinations.cloud_tpu_strategy,
          ],
          mode="eager"))
  def test_attention(self, distribution):
    """MultiHeadAttention in self, cross and cached single-step modes."""
    num_heads, head_size = 2, 4
    from_seq_length, to_seq_length = 4, 6
    batch_size = 2
    pos_embed = t5.RelativePositionEmbedding(
        num_heads=4,
        bidirectional=False,
        embeddings_initializer=tf_keras.initializers.Ones(),
        name="pos_embed")
    position_bias = pos_embed(from_seq_length, from_seq_length)
    attention = t5.MultiHeadAttention(
        d_model=4, d_kv=2, num_heads=4, dropout_rate=0.1)
    query = tf.convert_to_tensor(
        np.ones((batch_size, from_seq_length, 4), dtype=np.float32))
    self.assertEqual(
        attention(query, position_bias=position_bias)["context"].shape,
        query.shape)

    kv = tf.convert_to_tensor(
        np.ones((batch_size, to_seq_length, 4), dtype=np.float32))
    position_bias = pos_embed(from_seq_length, to_seq_length)
    outputs = attention(query, kv=kv, position_bias=position_bias)
    self.assertEqual(outputs["context"].shape, query.shape)

    # Defined before `step` so the closure's dependency is visible up front.
    decode_position = 2
    with distribution.scope():
      cached_attention = t5.MultiHeadAttention(
          d_model=4, d_kv=head_size, num_heads=num_heads, dropout_rate=0.1)

      @tf.function
      def step(inputs):

        def _step_fn(inputs):
          cache = _create_cache(batch_size, from_seq_length, num_heads,
                                head_size)
          mask = t5.make_causal_mask(tf.ones((batch_size, 1)))
          return cached_attention(
              query=inputs,
              mask=mask,
              cache=cache,
              decode_position=decode_position)

        outputs = distribution.run(_step_fn, args=(inputs,))
        return tf.nest.map_structure(distribution.experimental_local_results,
                                     outputs)

      query = tf.convert_to_tensor(
          np.ones((batch_size, 1, 4), dtype=np.float32))
      local_outputs = step(query)
      self.assertEqual(local_outputs["context"][0].shape, (batch_size, 1, 4))
      # The cache slot at `decode_position` must have been written to.
      updated_keys = local_outputs["cache"]["key"][0][:, decode_position, ...]
      self.assertNotEqual(np.sum(updated_keys.numpy()), 0.0)


class T5Test(tf.test.TestCase, parameterized.TestCase):
  """Tests for the t5 attention layers, blocks, encoder, decoder and model."""

  def _encoder_params(self, **overrides):
    """Returns the two-layer T5TransformerParams used by the stack tests."""
    params = dict(
        num_layers=2,
        d_model=4,
        d_kv=3,
        num_heads=4,
        d_ff=16,
        vocab_size=10,
        vocab_embeddings_initializer=tf_keras.initializers.Ones(),
        relative_embeddings_initializer=tf_keras.initializers.Ones())
    params.update(overrides)
    return t5.T5TransformerParams(**params)

  def _transformer_params(self, **overrides):
    """Returns the one-layer T5TransformerParams used by T5Transformer tests."""
    params = dict(
        num_layers=1,
        d_model=8,
        d_kv=4,
        num_heads=4,
        d_ff=32,
        vocab_size=10,
        shared_embedding=True)
    params.update(overrides)
    return t5.T5TransformerParams(**params)

  def _token_inputs(self):
    """Returns the fixed `(tokens, segment_ids)` tensors, each (2, 6)."""
    tokens = tf.convert_to_tensor(
        np.array([[2, 2, 1, 3, 1, 0], [3, 3, 1, 2, 2, 1]]))
    segments = tf.convert_to_tensor(
        np.array([[1, 1, 1, 2, 2, 0], [1, 1, 1, 2, 2, 2]]))
    return tokens, segments

  def _dense_inputs(self, dtype):
    """Returns random (2, 2, 8) dense inputs and their (2, 2) segment ids."""
    dense_inputs = tf.convert_to_tensor(np.random.randn(2, 2, 8), dtype=dtype)
    dense_segments = tf.convert_to_tensor(np.array([[1, 2], [1, 2]]))
    return dense_inputs, dense_segments

  def _check_single_step_decode(self,
                                transformer,
                                config,
                                encoded,
                                dtype,
                                num_decoder_layers=1,
                                max_decode_len=10,
                                **encoder_inputs):
    """Runs one cached decode step and checks logits shape and variable dtype.

    Args:
      transformer: The `t5.T5Transformer` under test.
      config: The params object `transformer` was built from.
      encoded: Encoder output passed to `transformer.decode`.
      dtype: dtype of the pre-allocated cache tensors.
      num_decoder_layers: Number of per-layer cache entries to create.
      max_decode_len: Length of the pre-allocated cache.
      **encoder_inputs: Encoder-side keyword arguments forwarded unchanged to
        `transformer.decode` (e.g. `encoder_input_tokens`).
    """
    batch_size = 2
    cache = {
        layer_idx: _create_cache(
            batch_size,
            max_decode_len,
            config.num_heads,
            config.d_kv,
            dtype=dtype) for layer_idx in range(num_decoder_layers)
    }
    outputs = transformer.decode(
        encoded=encoded,
        decoder_target_tokens=tf.ones((batch_size, 1), dtype=tf.int32),
        decode_position=1,
        decode=True,
        max_decode_len=max_decode_len,
        cache=cache,
        **encoder_inputs)
    self.assertEqual(outputs["logits"].shape,
                     (batch_size, 1, config.vocab_size))
    for v in transformer.trainable_variables:
      self.assertEqual(v.dtype, tf.float32)

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.default_strategy,
              strategy_combinations.cloud_tpu_strategy,
          ],
          mode="eager"))
  def test_attention_layers(self, distribution):
    """SelfAttention with a cache and CrossAttention under a strategy."""
    num_heads, head_size = 2, 2
    from_seq_length = 4
    # TPU decoding should pre-allocate the entire sequence.
    batch_size = 2
    with distribution.scope():
      # The bias table is sized by the number of attention heads, so it must
      # use `num_heads` (not `head_size`) to line up with the attention layer.
      pos_embed = t5.RelativePositionEmbedding(
          num_heads=num_heads,
          bidirectional=False,
          embeddings_initializer=tf_keras.initializers.Ones(),
          name="pos_embed")
      self_attention = t5.SelfAttention(
          d_model=4, d_kv=head_size, num_heads=num_heads, dropout_rate=0.1)
      decode_position = 2

      @tf.function
      def step(inputs):

        def _step_fn(inputs):
          cache = _create_cache(batch_size, from_seq_length, num_heads,
                                head_size)
          mask = t5.make_causal_mask(tf.ones((batch_size, 1)))
          position_bias = pos_embed(from_seq_length, from_seq_length)
          return self_attention(
              hidden_states=inputs,
              cache=cache,
              attention_mask=mask,
              decode_position=decode_position,
              position_bias=position_bias)

        outputs = distribution.run(_step_fn, args=(inputs,))
        return tf.nest.map_structure(distribution.experimental_local_results,
                                     outputs)

      query = tf.convert_to_tensor(
          np.ones((batch_size, 1, 4), dtype=np.float32))
      local_outputs = step(query)
      self.assertEqual(local_outputs["layer_output"][0].shape,
                       (batch_size, 1, 4))
      # The cache slot at `decode_position` must have been written to.
      updated_keys = local_outputs["cache"]["key"][0][:, decode_position, :, :]
      self.assertNotEqual(np.sum(updated_keys.numpy()), 0.0)

      cross_attention = t5.CrossAttention(
          d_model=4, d_kv=head_size, num_heads=num_heads, dropout_rate=0.1)
      to_seq_length = 6
      query = tf.convert_to_tensor(
          np.ones((batch_size, from_seq_length, 4), dtype=np.float32))
      kv = tf.convert_to_tensor(
          np.ones((batch_size, to_seq_length, 4), dtype=np.float32))

      @tf.function
      def step_cross_attn(inputs):

        def _step_fn(inputs):
          query, kv = inputs
          mask = t5.make_attention_mask(
              tf.ones((batch_size, from_seq_length)),
              tf.ones((batch_size, to_seq_length)))
          return cross_attention(
              hidden_states=query, kv=kv, attention_mask=mask)

        outputs = distribution.run(_step_fn, args=(inputs,))
        return tf.nest.map_structure(distribution.experimental_local_results,
                                     outputs)

      local_outputs = step_cross_attn((query, kv))
      self.assertEqual(local_outputs["layer_output"][0].shape,
                       (batch_size, from_seq_length, 4))

  def test_encoder_block(self):
    """EncoderBlock output has the same shape as its input."""
    batch_size = 2
    from_seq_length = 5
    d_model = 4
    block = t5.EncoderBlock(d_model=4, d_kv=3, num_heads=2, d_ff=8, name="foo")
    pos_embed = t5.RelativePositionEmbedding(
        num_heads=2,
        bidirectional=True,
        embeddings_initializer=tf_keras.initializers.Ones(),
        name="bar")
    attention_mask = t5.make_attention_mask(
        tf.ones((batch_size, from_seq_length)),
        tf.ones((batch_size, from_seq_length)))
    position_bias = pos_embed(from_seq_length, from_seq_length)
    inputs = tf.ones((batch_size, from_seq_length, d_model), dtype=tf.float32)
    outputs = block(
        inputs, attention_mask=attention_mask, position_bias=position_bias)
    self.assertEqual(outputs.shape, (batch_size, from_seq_length, d_model))

  def test_encdec_block(self):
    """EncDecoderBlock's first output has the decoder-side input shape."""
    batch_size = 2
    from_seq_length = 5
    to_seq_length = 3
    d_model = 4
    block = t5.EncDecoderBlock(
        d_model=4, d_kv=3, num_heads=2, d_ff=8, name="foo")
    pos_embed = t5.RelativePositionEmbedding(
        num_heads=2,
        bidirectional=True,
        embeddings_initializer=tf_keras.initializers.Ones(),
        name="bar")
    encoder_decoder_mask = t5.make_attention_mask(
        tf.ones((batch_size, from_seq_length)),
        tf.ones((batch_size, to_seq_length)))
    position_bias = pos_embed(from_seq_length, from_seq_length)
    inputs = tf.ones((batch_size, from_seq_length, d_model), dtype=tf.float32)
    encoder_hidden_states = tf.ones((batch_size, to_seq_length, d_model),
                                    dtype=tf.float32)
    outputs = block(
        inputs,
        encoder_hidden_states,
        encoder_decoder_mask=encoder_decoder_mask,
        position_bias=position_bias)
    self.assertEqual(outputs[0].shape, (batch_size, from_seq_length, d_model))

  @parameterized.named_parameters(("bfloat16", tf.bfloat16),
                                  ("float32", tf.float32))
  def test_encoder(self, dtype):
    """Encoder maps (4, 8) token ids to (4, 8, d_model)."""
    config = self._encoder_params()
    encoder = t5.Encoder(config, compute_dtype=dtype)
    encoded = encoder(tf.zeros((4, 8), dtype=tf.int32))
    self.assertEqual(encoded.shape, (4, 8, config.d_model))

  @parameterized.named_parameters(("return_score", True),
                                  ("not_return_score", False))
  def test_encoder_att_scores(self, return_attention_scores):
    """Encoder optionally returns one attention-score tensor per layer."""
    config = self._encoder_params(
        return_attention_scores=return_attention_scores)
    encoder = t5.Encoder(config, compute_dtype=tf.float32)
    outputs = encoder(tf.zeros((4, 8), dtype=tf.int32))
    if return_attention_scores:
      encoded, scores = outputs
      self.assertIsNotNone(scores)
      self.assertLen(scores, 2)
      self.assertEqual(scores[0].shape, (4, 4, 8, 8))
    else:
      encoded = outputs
    self.assertEqual(encoded.shape, (4, 8, config.d_model))

  @parameterized.named_parameters(("bfloat16", tf.bfloat16),
                                  ("float32", tf.float32))
  def test_encoder_with_dense(self, dtype):
    """Token and dense inputs together give a length-10 encoded sequence."""
    config = self._encoder_params()
    encoder = t5.Encoder(config, compute_dtype=dtype)
    encoded = encoder(
        tf.zeros((4, 8), dtype=tf.int32),
        dense_inputs=tf.ones((4, 2, 4), dtype=dtype))
    self.assertEqual(encoded.shape, (4, 10, config.d_model))

  @parameterized.named_parameters(("bfloat16", tf.bfloat16),
                                  ("float32", tf.float32))
  def test_encoder_only_dense(self, dtype):
    """Encoder accepts dense inputs without any token ids."""
    config = self._encoder_params()
    encoder = t5.Encoder(config, compute_dtype=dtype)
    encoded = encoder(dense_inputs=tf.ones((4, 2, 4), dtype=dtype))
    self.assertEqual(encoded.shape, (4, 2, config.d_model))

  def test_decoder(self):
    """Decoder logits shapes for a full pass and one cached decode step."""
    max_decode_len = 10
    batch_size = 4
    decode_position = 2
    config = self._encoder_params()
    decoder = t5.Decoder(config)
    targets = tf.zeros((batch_size, 8), dtype=tf.int32)
    encoded = tf.zeros((batch_size, 8, config.d_model), dtype=tf.float32)
    outputs = decoder(targets, encoded)
    self.assertEqual(outputs["logits"].shape,
                     (batch_size, 8, config.vocab_size))

    # One cache entry per layer of the stack.
    cache = {
        layer_idx: _create_cache(batch_size, max_decode_len, config.num_heads,
                                 config.d_kv)
        for layer_idx in range(config.num_layers)
    }
    targets = tf.zeros((batch_size, 1), dtype=tf.int32)
    outputs = decoder(
        targets,
        encoded,
        decode_position=decode_position,
        cache=cache,
        decode=True,
        max_decode_len=max_decode_len)
    self.assertEqual(outputs["logits"].shape,
                     (batch_size, 1, config.vocab_size))
    # Every layer's key and value must be populated at `decode_position`.
    for entry in outputs["cache"].values():
      for tensor in entry.values():
        self.assertNotAllEqual(tensor.numpy()[:, decode_position, :, :], 0.0)

  @parameterized.named_parameters(
      ("t5_10", ("relu",), True, 26, False, tf.float32),
      ("t5_11", ("gelu", "linear"), False, 29, False, tf.float32),
      ("t5_10_bfloat16", ("relu",), True, 26, False, tf.bfloat16),
      ("t5_11_bfloat16", ("gelu", "linear"), False, 29, False, tf.bfloat16),
      ("t5_10_layer_sharing", ("relu",), True, 26, True, tf.float32),
      ("t5_11_layer_sharing", ("gelu", "linear"), False, 29, True, tf.float32),
      ("t5_10_bfloat16_layer_sharing", ("relu",), True, 26, True, tf.bfloat16),
      ("t5_11_bfloat16_layer_sharing",
       ("gelu", "linear"), False, 29, True, tf.bfloat16))
  def test_transformer(self, ffn_activations, logits_via_embedding,
                       expect_num_variables, layer_sharing, dtype):
    """T5Transformer variable count, forward pass and one decode step."""
    config = self._transformer_params(
        layer_sharing=layer_sharing,
        ffn_activations=ffn_activations,
        logits_via_embedding=logits_via_embedding)
    transformer = t5.T5Transformer(config, compute_dtype=dtype)
    self.assertLen(transformer.trainable_variables, expect_num_variables)
    inputs, segments = self._token_inputs()

    outputs = transformer(
        encoder_input_tokens=inputs,
        decoder_input_tokens=inputs,
        decoder_target_tokens=inputs,
        encoder_segment_ids=segments,
        decoder_segment_ids=segments)
    self._check_single_step_decode(
        transformer,
        config,
        outputs["encoded"],
        dtype,
        encoder_input_tokens=inputs)

  def test_transformer_return_attn_scores(self):
    """T5Transformer exposes attention scores when configured to."""
    config = self._transformer_params(
        layer_sharing=False,
        ffn_activations=("relu",),
        logits_via_embedding=True,
        return_attention_scores=True,
    )
    transformer = t5.T5Transformer(config, compute_dtype=tf.float32)
    self.assertLen(transformer.trainable_variables, 26)
    inputs, segments = self._token_inputs()

    outputs = transformer(
        encoder_input_tokens=inputs,
        decoder_input_tokens=inputs,
        decoder_target_tokens=inputs,
        encoder_segment_ids=segments,
        decoder_segment_ids=segments,
    )
    self.assertIn("attention_scores", outputs)
    self.assertLen(outputs["attention_scores"], 1)
    self.assertEqual(outputs["attention_scores"][0].shape, (2, 4, 6, 6))
    self._check_single_step_decode(
        transformer,
        config,
        outputs["encoded"],
        tf.float32,
        encoder_input_tokens=inputs)

  @parameterized.named_parameters(
      ("t5_10_dense", ("relu",), True, 26, False, tf.float32),)
  def test_transformer_with_dense(self, ffn_activations, logits_via_embedding,
                                  expect_num_variables, layer_sharing, dtype):
    """T5Transformer with both token and dense encoder inputs."""
    config = self._transformer_params(
        layer_sharing=layer_sharing,
        ffn_activations=ffn_activations,
        logits_via_embedding=logits_via_embedding)
    transformer = t5.T5Transformer(config, compute_dtype=dtype)
    self.assertLen(transformer.trainable_variables, expect_num_variables)
    inputs, segments = self._token_inputs()
    dense_inputs, dense_segments = self._dense_inputs(dtype)

    outputs = transformer(
        encoder_input_tokens=inputs,
        encoder_dense_inputs=dense_inputs,
        decoder_input_tokens=inputs,
        decoder_target_tokens=inputs,
        encoder_segment_ids=segments,
        encoder_dense_segment_ids=dense_segments,
        decoder_segment_ids=segments)
    self._check_single_step_decode(
        transformer,
        config,
        outputs["encoded"],
        dtype,
        encoder_input_tokens=inputs,
        encoder_dense_inputs=dense_inputs)

  @parameterized.named_parameters(
      ("t5_10_dense_layerwise_relpos",
       ("relu",), True, 26, False, tf.float32, False, 1),
      ("t5_10_dense_shared_relpos_d2",
       ("relu",), True, 39, False, tf.float32, True, 2),
      ("t5_10_dense_layerwise_relpos_d2",
       ("relu",), True, 40, False, tf.float32, False, 2),
  )
  def test_transformer_with_lw_relpos(self, ffn_activations,
                                      logits_via_embedding,
                                      expect_num_variables, layer_sharing,
                                      dtype, use_shared_relpos,
                                      num_decoder_layers):
    """T5Transformer with shared vs. non-shared relative position bias."""
    config = self._transformer_params(
        num_decoder_layers=num_decoder_layers,
        layer_sharing=layer_sharing,
        ffn_activations=ffn_activations,
        logits_via_embedding=logits_via_embedding,
        use_shared_relative_position_bias=use_shared_relpos)
    transformer = t5.T5Transformer(config, compute_dtype=dtype)
    self.assertLen(transformer.trainable_variables, expect_num_variables)
    inputs, segments = self._token_inputs()
    dense_inputs, dense_segments = self._dense_inputs(dtype)

    outputs = transformer(
        encoder_input_tokens=inputs,
        encoder_dense_inputs=dense_inputs,
        decoder_input_tokens=inputs,
        decoder_target_tokens=inputs,
        encoder_segment_ids=segments,
        encoder_dense_segment_ids=dense_segments,
        decoder_segment_ids=segments)
    self._check_single_step_decode(
        transformer,
        config,
        outputs["encoded"],
        dtype,
        num_decoder_layers=num_decoder_layers,
        encoder_input_tokens=inputs,
        encoder_dense_inputs=dense_inputs)

  @parameterized.named_parameters(
      ("t5_10", ("relu",), True, 26, False, tf.float32),)
  def test_transformer_with_dense_only(self, ffn_activations,
                                       logits_via_embedding,
                                       expect_num_variables, layer_sharing,
                                       dtype):
    """T5Transformer with dense encoder inputs and no encoder tokens."""
    config = self._transformer_params(
        layer_sharing=layer_sharing,
        ffn_activations=ffn_activations,
        logits_via_embedding=logits_via_embedding)
    transformer = t5.T5Transformer(config, compute_dtype=dtype)
    self.assertLen(transformer.trainable_variables, expect_num_variables)
    decoder_inputs, decoder_segments = self._token_inputs()
    dense_inputs, dense_segments = self._dense_inputs(dtype)

    outputs = transformer(
        encoder_dense_inputs=dense_inputs,
        encoder_dense_segment_ids=dense_segments,
        decoder_input_tokens=decoder_inputs,
        decoder_target_tokens=decoder_inputs,
        decoder_segment_ids=decoder_segments)
    self._check_single_step_decode(
        transformer,
        config,
        outputs["encoded"],
        dtype,
        encoder_dense_inputs=dense_inputs)

  @parameterized.named_parameters(
      ("t5_10", ("relu",), True, 39, tf.float32, 2),
      ("t5_10_bfloat16", ("relu",), True, 39, tf.bfloat16, 2))
  def test_transformer_different_num_decoder_layers(self, ffn_activations,
                                                    logits_via_embedding,
                                                    expect_num_variables,
                                                    dtype,
                                                    num_decoder_layers):
    """T5Transformer whose decoder depth differs from `num_layers`."""
    config = self._transformer_params(
        num_decoder_layers=num_decoder_layers,
        ffn_activations=ffn_activations,
        logits_via_embedding=logits_via_embedding)
    transformer = t5.T5Transformer(config, compute_dtype=dtype)
    self.assertLen(transformer.trainable_variables, expect_num_variables)
    inputs, segments = self._token_inputs()

    outputs = transformer(
        encoder_input_tokens=inputs,
        decoder_input_tokens=inputs,
        decoder_target_tokens=inputs,
        encoder_segment_ids=segments,
        decoder_segment_ids=segments)
    self._check_single_step_decode(
        transformer,
        config,
        outputs["encoded"],
        dtype,
        num_decoder_layers=num_decoder_layers,
        encoder_input_tokens=inputs)


if __name__ == "__main__":
  tf.test.main()