# Copyright 2019, The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools

from absl.testing import parameterized
import numpy as np
import tensorflow as tf

from tensorflow_model_optimization.python.core.internal.tensor_encoding.stages.research import misc
from tensorflow_model_optimization.python.core.internal.tensor_encoding.testing import test_utils


# Turn eager execution off for this module when TensorFlow starts in eager mode.
# NOTE(review): presumably the test_utils helpers used below build graph
# tensors and evaluate them afterwards — confirm against test_utils.
if tf.executing_eagerly():
  tf.compat.v1.disable_eager_execution()


class SplitBySmallValueEncodingStageTest(test_utils.BaseEncodingStageTest):
  """Tests for `misc.SplitBySmallValueEncodingStage`."""

  def default_encoding_stage(self):
    """See base class."""
    return misc.SplitBySmallValueEncodingStage()

  def default_input(self):
    """See base class."""
    return tf.random.uniform([50], minval=-1.0, maxval=1.0)

  @property
  def is_lossless(self):
    """See base class."""
    # Values below the threshold are decoded as zeros (see test_input_types),
    # so the round trip does not reproduce the input exactly.
    return False

  def common_asserts_for_test_data(self, data):
    """See base class."""
    self._assert_is_integer(
        data.encoded_x[misc.SplitBySmallValueEncodingStage.ENCODED_INDICES_KEY])

  def _assert_is_integer(self, indices):
    """Asserts that the evaluated `indices` array has dtype `np.int32`.

    Uses a unittest assertion rather than a bare `assert` statement, so the
    check is not stripped under `python -O` and a failure reports both dtypes.

    Args:
      indices: Evaluated encoded indices; must expose a `dtype` attribute.
    """
    self.assertEqual(indices.dtype, np.int32)

  @parameterized.parameters([tf.float32, tf.float64])
  def test_input_types(self, x_dtype):
    """Tests encoding and decoding for the supported floating input dtypes.

    Args:
      x_dtype: The `tf.DType` of the input tensor under test.
    """
    x = tf.constant([1.0, 0.1, 0.01, 0.001, 0.0001], dtype=x_dtype)
    threshold = 0.05
    stage = misc.SplitBySmallValueEncodingStage(threshold=threshold)
    encode_params, decode_params = stage.get_params()
    encoded_x, decoded_x = self.encode_decode_x(stage, x, encode_params,
                                                decode_params)
    test_data = test_utils.TestData(x, encoded_x, decoded_x)
    test_data = self.evaluate_test_data(test_data)

    self._assert_is_integer(test_data.encoded_x[
        misc.SplitBySmallValueEncodingStage.ENCODED_INDICES_KEY])

    # The numpy arrays must have the same dtype as the arrays from test_data.
    np_dtype = x_dtype.as_numpy_dtype
    # Only the first two entries exceed the threshold of 0.05; the remaining
    # entries are expected to be decoded as zeros.
    expected_encoded_values = np.array([1.0, 0.1], dtype=np_dtype)
    expected_encoded_indices = np.array([0, 1], dtype=np.int32)
    expected_decoded_x = np.array([1.0, 0.1, 0., 0., 0.], dtype=np_dtype)
    self.assertAllEqual(test_data.encoded_x[stage.ENCODED_VALUES_KEY],
                        expected_encoded_values)
    self.assertAllEqual(test_data.encoded_x[stage.ENCODED_INDICES_KEY],
                        expected_encoded_indices)
    self.assertAllEqual(test_data.decoded_x, expected_decoded_x)

  def test_all_zero_input_works(self):
    """Tests that an all-zero input decodes back to all zeros."""
    # Tests that encoding does not blow up with all-zero input. With all-zero
    # input, both of the encoded values will be empty arrays.
    stage = misc.SplitBySmallValueEncodingStage()
    test_data = self.run_one_to_many_encode_decode(stage,
                                                   lambda: tf.zeros([50]))

    self.assertAllEqual(np.zeros((50)).astype(np.float32), test_data.decoded_x)

  def test_all_below_threshold_works(self):
    """Tests an input in which every element is below the threshold."""
    # Tests that encoding does not blow up with all-below-threshold input. In
    # this case, both of the encoded values will be empty arrays.
    stage = misc.SplitBySmallValueEncodingStage(threshold=0.1)
    x = tf.random.uniform([50], minval=-0.01, maxval=0.01)
    encode_params, decode_params = stage.get_params()
    encoded_x, decoded_x = self.encode_decode_x(stage, x, encode_params,
                                                decode_params)
    test_data = test_utils.TestData(x, encoded_x, decoded_x)
    test_data = self.evaluate_test_data(test_data)

    expected_encoded_indices = np.array([], dtype=np.int32).reshape([0])
    self.assertAllEqual(test_data.encoded_x[stage.ENCODED_VALUES_KEY], [])
    self.assertAllEqual(test_data.encoded_x[stage.ENCODED_INDICES_KEY],
                        expected_encoded_indices)
    self.assertAllEqual(test_data.decoded_x,
                        np.zeros([50], dtype=x.dtype.as_numpy_dtype))


class DifferenceBetweenIntegersEncodingStageTest(
    test_utils.BaseEncodingStageTest):
  """Tests for `misc.DifferenceBetweenIntegersEncodingStage`."""

  def default_encoding_stage(self):
    """See base class."""
    return misc.DifferenceBetweenIntegersEncodingStage()

  def default_input(self):
    """See base class."""
    return tf.random.uniform([10], minval=0, maxval=10, dtype=tf.int64)

  @property
  def is_lossless(self):
    """See base class."""
    return True

  def common_asserts_for_test_data(self, data):
    """See base class."""
    # The stage is declared lossless, so decoding must reproduce the input.
    self.assertAllEqual(data.x, data.decoded_x)

  @parameterized.parameters(
      itertools.product([[1,], [2,], [10,]], [tf.int32, tf.int64]))
  def test_with_multiple_input_shapes(self, input_dims, dtype):
    """Tests the encode-decode round trip for several 1-D sizes and dtypes.

    Args:
      input_dims: Shape, as a list, of the random input to generate.
      dtype: Integer `tf.DType` of the random input.
    """

    def x_fn():
      return tf.random.uniform(input_dims, minval=0, maxval=10, dtype=dtype)

    test_data = self.run_one_to_many_encode_decode(
        self.default_encoding_stage(), x_fn)
    self.common_asserts_for_test_data(test_data)

  def test_empty_input_static(self):
    """Tests that the encoding works when the input shape is statically [0]."""
    x = []
    x = tf.convert_to_tensor(x, dtype=tf.int32)
    # Sanity check of the test setup: the shape is known at graph build time.
    self.assertEqual(x.shape.as_list(), [0])

    stage = self.default_encoding_stage()
    encode_params, decode_params = stage.get_params()
    encoded_x, decoded_x = self.encode_decode_x(stage, x, encode_params,
                                                decode_params)

    test_data = self.evaluate_test_data(
        test_utils.TestData(x, encoded_x, decoded_x))
    self.common_asserts_for_test_data(test_data)

  def test_empty_input_dynamic(self):
    """Tests an empty input whose shape is not statically known."""
    # Tests that the encoding works when the input shape is [0], but not
    # statically known.
    y = tf.zeros((10,))
    # No element of y passes the filter, so the gathered tensor is empty, while
    # its static shape stays unknown.
    indices = tf.compat.v2.where(tf.abs(y) > 1e-8)
    x = tf.gather_nd(y, indices)
    x = tf.cast(x, tf.int32)  # Empty tensor.
    self.assertEqual(x.shape.as_list(), [None])
    stage = self.default_encoding_stage()
    encode_params, decode_params = stage.get_params()
    encoded_x, decoded_x = self.encode_decode_x(stage, x, encode_params,
                                                decode_params)

    test_data = self.evaluate_test_data(
        test_utils.TestData(x, encoded_x, decoded_x))
    self.assertEqual(test_data.x.shape, (0,))
    self.assertEqual(test_data.encoded_x[stage.ENCODED_VALUES_KEY].shape, (0,))
    self.assertEqual(test_data.decoded_x.shape, (0,))

  @parameterized.parameters([tf.bool, tf.float32])
  def test_encode_unsupported_type_raises(self, dtype):
    """Tests that a non-integer input dtype raises a `TypeError`.

    Args:
      dtype: The unsupported `tf.DType` the default input is cast to.
    """
    stage = self.default_encoding_stage()
    # assertRaisesRegexp is a deprecated alias that was removed in Python 3.12.
    with self.assertRaisesRegex(TypeError, 'Unsupported input type'):
      self.run_one_to_many_encode_decode(
          stage, lambda: tf.cast(self.default_input(), dtype))

  def test_encode_unsupported_input_shape_raises(self):
    """Tests that encoding an input of rank other than 1 raises `ValueError`."""
    x = tf.random.uniform((3, 4), maxval=10, dtype=tf.int32)
    stage = self.default_encoding_stage()
    params, _ = stage.get_params()
    # assertRaisesRegexp is a deprecated alias that was removed in Python 3.12.
    with self.assertRaisesRegex(ValueError, 'Number of dimensions must be 1'):
      stage.encode(x, params)


# Run the test cases in this module via tf.test.main() when the file is
# executed as a script.
if __name__ == '__main__':
  tf.test.main()