# Hugging Face Space: ASL-MoViNet-T5-translator (page status: "Runtime error")
# Path: official/recommendation/uplift/layers/encoders/concat_features_test.py
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the concat_features encoder layer."""
from absl.testing import parameterized | |
import tensorflow as tf, tf_keras | |
from official.recommendation.uplift import keras_test_case | |
from official.recommendation.uplift.layers.encoders import concat_features | |
class ConcatFeaturesTest(keras_test_case.KerasTestCase, parameterized.TestCase):
  """Tests for `concat_features.ConcatFeatures`.

  The layer is exercised as a dict-in / tensor-out encoder: it selects the
  features listed in `feature_names`, ignores any other keys, and joins the
  selected features along their last dimension. It must also reject missing
  features, unsupported tensor types, empty or non-string `feature_names`, and
  features whose shapes differ anywhere except the last dimension.
  """

  @staticmethod
  def _mixed_tensor_inputs():
    """Returns dense, sparse and ragged features plus one unused feature."""
    return {
        "dense": tf.constant([-1.4, 2.0], shape=(2, 1)),
        "sparse": tf.sparse.SparseTensor(
            indices=[[0, 1], [1, 0]],
            values=[2.718, 3.14],
            dense_shape=[2, 2],
        ),
        "ragged": tf.ragged.constant([[5, 7.77], [8]]),
        # Not listed in `feature_names`; the layer should ignore it.
        "other_feature": tf.ones((2, 5)),
    }

  @parameterized.named_parameters(
      {
          "testcase_name": "single_dense_feature",
          "feature_names": ["feature"],
          "inputs": {"feature": tf.constant([[1.0], [2.0]])},
          "expected_output": tf.constant([[1.0], [2.0]]),
      },
      {
          "testcase_name": "two_dense_features",
          "feature_names": ["feature1", "feature2"],
          "inputs": {
              "feature1": tf.constant([[1.0], [2.0]]),
              "feature2": tf.constant([[3.0, 4.0], [5.0, 6.0]]),
          },
          "expected_output": tf.constant([[1.0, 3.0, 4.0], [2.0, 5.0, 6.0]]),
      },
      {
          "testcase_name": "unlisted_features_ignored",
          "feature_names": ["feature"],
          "inputs": {
              "feature": tf.ones((2, 1)),
              "other_feature": tf.zeros((2, 3)),
          },
          "expected_output": tf.ones((2, 1)),
      },
      {
          # Unset sparse positions are expected to densify to zero.
          "testcase_name": "dense_and_sparse_features",
          "feature_names": ["dense", "sparse"],
          "inputs": {
              "dense": tf.constant([[-1.0], [2.0]]),
              "sparse": tf.sparse.SparseTensor(
                  indices=[[0, 1], [1, 0]],
                  values=[3.0, 4.0],
                  dense_shape=[2, 2],
              ),
          },
          "expected_output": tf.constant([[-1.0, 0.0, 3.0], [2.0, 4.0, 0.0]]),
      },
      {
          # Assumes ragged rows are zero-padded to the longest row (the
          # `RaggedTensor.to_tensor` default) -- confirm against the layer.
          "testcase_name": "dense_and_ragged_features",
          "feature_names": ["dense", "ragged"],
          "inputs": {
              "dense": tf.constant([[-1.0], [2.0]]),
              "ragged": tf.ragged.constant([[5.0, 6.0], [7.0]]),
          },
          "expected_output": tf.constant([[-1.0, 5.0, 6.0], [2.0, 7.0, 0.0]]),
      },
  )
  def test_layer_correctness(self, feature_names, inputs, expected_output):
    """Eager inputs are concatenated into the expected dense tensor."""
    layer = concat_features.ConcatFeatures(feature_names=feature_names)
    self.assertAllClose(expected_output, layer(inputs))

  @parameterized.named_parameters(
      {
          "testcase_name": "single_feature",
          "inputs": {"feature": tf_keras.Input(shape=(3,))},
          "expected_shape": [None, 3],
      },
      {
          "testcase_name": "two_rank_2_features",
          "inputs": {
              "feature1": tf_keras.Input(shape=(1,)),
              "feature2": tf_keras.Input(shape=(4,)),
          },
          "expected_shape": [None, 5],
      },
      {
          "testcase_name": "two_rank_3_features",
          "inputs": {
              "feature1": tf_keras.Input(shape=(2, 1)),
              "feature2": tf_keras.Input(shape=(2, 3)),
          },
          "expected_shape": [None, 2, 4],
      },
  )
  def test_layer_correctness_keras_inputs(self, inputs, expected_shape):
    """Symbolic Keras inputs yield a KerasTensor with the concatenated shape."""
    layer = concat_features.ConcatFeatures(feature_names=list(inputs.keys()))
    output = layer(inputs)
    # tf_keras does not expose its KerasTensor class publicly, so take it from
    # a throwaway Input.
    keras_tensor_cls = type(tf_keras.Input(shape=(1,)))
    self.assertIsInstance(output, keras_tensor_cls)
    self.assertEqual(tf.TensorShape(expected_shape), output.shape)

  def test_layer_stability(self):
    """Repeated calls on mixed dense/sparse/ragged inputs are stable."""
    layer = concat_features.ConcatFeatures(
        feature_names=["dense", "sparse", "ragged"]
    )
    self.assertLayerStable(inputs=self._mixed_tensor_inputs(), layer=layer)

  def test_layer_savable(self):
    """The layer can be saved after being called on mixed inputs."""
    layer = concat_features.ConcatFeatures(
        feature_names=["dense", "sparse", "ragged"]
    )
    self.assertLayerSavable(inputs=self._mixed_tensor_inputs(), layer=layer)

  def test_missing_input_features(self):
    """A name in `feature_names` that is absent from the inputs is an error."""
    layer = concat_features.ConcatFeatures(feature_names=["feature"])
    with self.assertRaisesRegex(
        ValueError, "Layer inputs is missing features*"
    ):
      layer({"other_feature": tf.ones((3, 1))})

  def test_unsupported_tensor_type(self):
    """Tensor types other than dense, sparse or ragged are rejected."""

    class TestType(tf.experimental.ExtensionType):
      tensor: tf.Tensor

    layer = concat_features.ConcatFeatures(feature_names=["feature"])
    with self.assertRaisesRegex(TypeError, "Got unsupported tensor shape type"):
      layer({
          "feature": TestType(tensor=tf.ones((3, 1))),
          "other_feature": tf.ones((3, 1)),
      })

  def test_empty_feature_names_list(self):
    """Construction fails when `feature_names` is empty."""
    with self.assertRaisesRegex(
        ValueError, "feature_names must be a non-empty list"
    ):
      concat_features.ConcatFeatures(feature_names=[])

  def test_non_string_feature_name(self):
    """Construction fails when any feature name is not a string."""
    with self.assertRaisesRegex(
        TypeError, "feature_names must be a list of strings"
    ):
      concat_features.ConcatFeatures(feature_names=["x", 1])

  @parameterized.named_parameters(
      {
          "testcase_name": "dense_batch_size_mismatch",
          "inputs": {
              "feature1": tf.ones((3, 1)),
              "feature2": tf.ones((2, 1)),
          },
      },
      {
          "testcase_name": "keras_inner_dimension_mismatch",
          "inputs": {
              "feature1": tf_keras.Input(shape=(2, 1)),
              "feature2": tf_keras.Input(shape=(3, 1)),
          },
      },
  )
  def test_shape_mismatch(self, inputs):
    """Features whose non-last dimensions differ cannot be concatenated."""
    layer = concat_features.ConcatFeatures(feature_names=list(inputs.keys()))
    with self.assertRaisesRegex(
        ValueError,
        (
            "All features from the feature_names set must be tensors with the"
            " same shape except for the last dimension"
        ),
    ):
      layer(inputs)

  @parameterized.named_parameters(
      {
          "testcase_name": "dense_rank_2_and_rank_3",
          "inputs": {
              "feature1": tf.ones((3, 1)),
              "feature2": tf.ones((3, 1, 1)),
          },
      },
      {
          "testcase_name": "keras_rank_2_and_rank_3",
          "inputs": {
              "feature1": tf_keras.Input(shape=(1,)),
              "feature2": tf_keras.Input(shape=(1, 1)),
          },
      },
  )
  def test_rank_mismatch(self, inputs):
    """Features of different rank are rejected with the shape-mismatch error."""
    layer = concat_features.ConcatFeatures(feature_names=list(inputs.keys()))
    with self.assertRaisesRegex(
        ValueError,
        (
            "All features from the feature_names set must be tensors with the"
            " same shape except for the last dimension"
        ),
    ):
      layer(inputs)

  def test_layer_config(self):
    """The layer round-trips through its config and is serializable."""
    layer = concat_features.ConcatFeatures(
        feature_names=["feature1", "feature2"], name="encoder", dtype=tf.float64
    )
    self.assertLayerConfigurable(layer=layer, serializable=True)
if __name__ == "__main__":
  # Hand control to TensorFlow's test runner, which discovers and runs every
  # test case defined in this module when the file is executed directly.
  tf.test.main()