# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utils used to manipulate tensor shapes."""

import tensorflow as tf
import tf_keras


def assert_shape_equal(shape_a, shape_b):
    """Checks that two shapes match, statically when possible.

    When every dimension of both shapes is a Python int the comparison is
    done immediately in Python. Otherwise the comparison is deferred to a
    `tf.assert_equal` op.

    Args:
      shape_a: a list containing the shape of the first tensor.
      shape_b: a list containing the shape of the second tensor.

    Returns:
      A `tf.no_op()` when both shapes are fully static and equal, or the
      result of `tf.assert_equal(shape_a, shape_b)` when at least one
      dimension is not a Python int.

    Raises:
      ValueError: When both shapes are fully static and they are unequal.
    """

    def _is_fully_static(shape):
        # A shape counts as static only if each entry is a plain Python int.
        return all(isinstance(dim, int) for dim in shape)

    # Dynamic case first: hand the comparison over to TensorFlow.
    if not (_is_fully_static(shape_a) and _is_fully_static(shape_b)):
        return tf.assert_equal(shape_a, shape_b)

    if shape_a != shape_b:
        raise ValueError('Unequal shapes {}, {}'.format(shape_a, shape_b))
    return tf.no_op()


def combined_static_and_dynamic_shape(tensor):
    """Builds a per-dimension shape list, preferring static sizes.

    For every dimension the statically known size is used when available;
    dimensions whose static size is None fall back to the corresponding
    element of `tf.shape(tensor)`. This is useful to preserve static shapes
    when available in reshape operations.

    Args:
      tensor: A tensor of any type.

    Returns:
      A list with one entry per dimension of `tensor`; each entry is either
      a Python integer (static size) or a scalar tensor (dynamic size).
    """
    static_dims = tensor.shape.as_list()
    dynamic_dims = tf.shape(input=tensor)
    return [
        dynamic_dims[axis] if size is None else size
        for axis, size in enumerate(static_dims)
    ]


def pad_or_clip_nd(tensor, output_shape):
    """Pads or clips the given tensor to the requested output shape.

    Args:
      tensor: Input tensor to pad or clip.
      output_shape: A list of integers / scalar tensors (or None for a
        dynamic dim) giving the size to pad or clip each dimension of the
        input tensor to. A None entry leaves that dimension untouched.

    Returns:
      The input tensor clipped and then padded to the output shape. Padding
      is only ever added after the existing data of a dimension.
    """
    input_dims = tf.shape(input=tensor)

    # Per-dimension slice size: the target size when the input is larger
    # than the target, otherwise -1 (passed through to tf.slice).
    slice_sizes = []
    for axis, target in enumerate(output_shape):
        if target is None:
            slice_sizes.append(-1)
        else:
            slice_sizes.append(
                tf.where(input_dims[axis] - target > 0, target, -1))

    slice_begin = tf.zeros(len(slice_sizes), dtype=tf.int32)
    clipped = tf.slice(tensor, begin=slice_begin, size=slice_sizes)

    # Pad tensor if the shape of the clipped tensor is smaller than the
    # expected shape.
    clipped_dims = tf.shape(input=clipped)
    pad_after = [
        0 if target is None else target - clipped_dims[axis]
        for axis, target in enumerate(output_shape)
    ]
    pad_before = tf.zeros(len(pad_after), dtype=tf.int32)
    paddings = tf.stack([pad_before, pad_after], axis=1)
    padded = tf.pad(tensor=clipped, paddings=paddings)

    # Only non-tensor entries of output_shape can be recorded statically.
    static_output_shape = [
        None if isinstance(dim, tf.Tensor) else dim for dim in output_shape
    ]
    padded.set_shape(static_output_shape)
    return padded