# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Registry responsible for built-in keras classes."""

import logging
import warnings

import tensorflow as tf

from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
from tensorflow_model_optimization.python.core.clustering.keras import clustering_registry
from tensorflow_model_optimization.python.core.keras.compat import keras
from tensorflow_model_optimization.python.core.quantization.keras import quant_ops
from tensorflow_model_optimization.python.core.quantization.keras import quantizers
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_registry
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantizers


layers = keras.layers
K = keras.backend

CLUSTER_CENTROIDS = 'cluster_centroids_tf'
PULLING_INDICES = 'pulling_indices_tf'
ORIGINAL_WEIGHTS = 'ori_weights_vars_tf'
WEIGHT_NAME = 'weight_name'
CLUSTERING_IMPL = 'clst_impl'
CENTROIDS_MASK = 'centroids_mask'
SPARSITY_MASK = 'sparsity_mask'


def get_unique(t):
  """Get unique values and lookup index from N-D tensor.

  Args:
    t: tensor
  Returns:
    centroids (unique values), lookup index (same shape as input tensor)
  Example:
    t:
    ([[1.0, 2.0],
      [2.0, 3.0],
      [3.0, 3.0],
      [1.0, 2.0]]
    )
    centroids(unique values):
    ([1.0, 2.0, 3.0])
    output final index:
    ([[0, 1],
      [1, 2],
      [2, 2],
      [0, 1]]
    )
  """
  t_flatten = tf.reshape(t, shape=(-1,))
  uniques, index = tf.unique(t_flatten)
  return uniques, tf.reshape(index, shape=tf.shape(t))
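
# A runnable sketch of `get_unique` (values match the docstring example):
#
#   t = tf.constant([[1.0, 2.0], [2.0, 3.0], [3.0, 3.0], [1.0, 2.0]])
#   centroids, lookup = get_unique(t)
#   # centroids == [1.0, 2.0, 3.0]; lookup has the same shape as `t`.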


def get_centroids(layer, weight, data_format):
  """Gets centroid infos from the weights of a layer.

  Args:
    layer: The Keras layer from which the weight belong.
    weight: The weight tensor to get the centroids info from.
    data_format: string to indicate format: "channels_first" or "channels_last".
  Returns:
    A 4-tuple of centroids (unique values), number of centroids, lookup index,
    whether to cluster per channel (boolean).
  """
  cluster_per_channel = layer.layer and isinstance(
      layer.layer, keras.layers.Conv2D
  )

  if not cluster_per_channel:
    centroids, index = get_unique(weight)
    return centroids, tf.size(centroids), index, False

  # In case of cluster_per_channel we need to extract
  # unique values (centroids) for each channel.
  num_channels = weight.shape[1 if data_format == 'channels_first' else -1]
  channel_centroids = []
  channel_indices = []
  num_centroids = []

  for channel in range(num_channels):
    channel_weights = weight[:, :, :, channel]
    centroids, indices = get_unique(channel_weights)

    channel_centroids.append(centroids)
    channel_indices.append(indices)
    num_centroids.append(tf.size(centroids))

  max_centroid = max(num_centroids)
  max_diff = max_centroid - min(num_centroids)

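  # If the per-channel centroid counts differ by more than one, the kernel was
  # likely not clustered per channel; fall back to extracting unique values
  # from the whole kernel.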
  if max_diff > 1:
    centroids, index = get_unique(weight)
    return centroids, tf.size(centroids), index, False

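  # Pad any channel that is short of centroids with ones, so that every
  # channel's centroids stack into one rectangular tensor.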
  for i, centroid in enumerate(channel_centroids):
    if num_centroids[i] != max_centroid:
      one_padding = tf.ones([max_centroid - num_centroids[i]])
      channel_centroids[i] = tf.concat([centroid, one_padding], 0)

  centroids = tf.convert_to_tensor(channel_centroids)
  lookup = tf.convert_to_tensor(channel_indices)

  lookup = tf.transpose(
      lookup,
      perm=(1, 0, 2, 3) if data_format == 'channels_first' else (1, 2, 3, 0))

  return centroids, max_centroid, lookup, True
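
# A minimal usage sketch of `get_centroids` (illustrative only; assumes
# `wrapped` is a quantize-wrapped layer, i.e. it exposes the inner Keras
# layer as `wrapped.layer`):
#
#   weight = wrapped.layer.kernel
#   data_format = getattr(wrapped.layer, 'data_format', None)
#   centroids, num_centroids, lookup, per_channel = get_centroids(
#       wrapped, weight, data_format)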


class _ClusterPreserveInfo(object):
  """ClusterPreserveInfo."""

  def __init__(self, weight_attrs, quantize_config_attrs):
    """ClusterPreserveInfo.

    Args:
      weight_attrs: list of cluster-preservable weight attributes of the
        layer.
      quantize_config_attrs: list of quantization configuration class names.
    """
    self.weight_attrs = weight_attrs
    self.quantize_config_attrs = quantize_config_attrs


class ClusterPreserveQuantizeRegistry(object):
  """ClusterPreserveQuantizeRegistry is for built-in keras layers."""
  # The keys are built-in Keras layer classes. For each layer, the first value
  # lists the layer attributes that hold the kernel weights, and the second
  # value lists the class names of the supported quantization configurations.
  # Together these decide which weights, under which quantization
  # configurations, are cluster preservable.
  _LAYERS_CONFIG_MAP = {
      layers.Conv2D:
      _ClusterPreserveInfo(['kernel'], ['Default8BitConvQuantizeConfig']),
      layers.Dense:
      _ClusterPreserveInfo(['kernel'], ['Default8BitQuantizeConfig']),

      # DepthwiseConv2D is supported with 8-bit QAT but not with clustering,
      # so preserving clustered weights is disabled for DepthwiseConv2D CQAT.
      layers.DepthwiseConv2D:
      _ClusterPreserveInfo(['depthwise_kernel'],
                           ['Default8BitQuantizeConfig']),

      # Layers that are supported with clustering, but not yet with QAT:
      # layers.Conv1D:
      # _ClusterPreserveInfo(['kernel'], []),
      # layers.Conv2DTranspose:
      # _ClusterPreserveInfo(['kernel'], []),
      # layers.Conv3D:
      # _ClusterPreserveInfo(['kernel'], []),
      # layers.Conv3DTranspose:
      # _ClusterPreserveInfo(['kernel'], []),
      # layers.LocallyConnected1D:
      # _ClusterPreserveInfo(['kernel'], ['Default8BitQuantizeConfig']),
      # layers.LocallyConnected2D:
      # _ClusterPreserveInfo(['kernel'], ['Default8BitQuantizeConfig']),

      # SeparableConv layers still need verification with 8-bit QAT:
      # layers.SeparableConv1D:
      # _ClusterPreserveInfo(['pointwise_kernel'],
      #                      ['Default8BitConvQuantizeConfig']),
      # layers.SeparableConv2D:
      # _ClusterPreserveInfo(['pointwise_kernel'],
      #                      ['Default8BitConvQuantizeConfig']),

      # Embedding still needs verification with 8-bit QAT:
      # layers.Embedding: _ClusterPreserveInfo(['embeddings'], []),
  }

  _DISABLE_CLUSTER_PRESERVE = frozenset({
      layers.DepthwiseConv2D,
  })

  def __init__(self, preserve_sparsity):
    self._config_quantizer_map = {
        'Default8BitQuantizeConfig':
        ClusterPreserveDefault8BitWeightsQuantizer(preserve_sparsity),
        'Default8BitConvQuantizeConfig':
        ClusterPreserveDefault8BitConvWeightsQuantizer(preserve_sparsity),
    }

  @classmethod
  def _no_trainable_weights(cls, layer):
    """Returns whether this layer has trainable weights.

    Args:
      layer: The layer to check for trainable weights.
    Returns:
      True/False whether the layer has trainable weights.
    """
    return not layer.trainable_weights

  @classmethod
  def _disable_cluster_preserve(cls, layer):
    """Returns whether to disable this layer for preserving clusters.

    Args:
      layer: The layer to check for disabling.
    Returns:
      True/False whether disabling this layer for preserving clusters.
    """
    return layer.__class__ in cls._DISABLE_CLUSTER_PRESERVE

  @classmethod
  def supports(cls, layer):
    """Returns whether the registry supports this layer type.

    Args:
      layer: The layer to check for support.
    Returns:
      True/False whether the layer type is supported.
    """
    # Layers without trainable weights are considered supported,
    # e.g., ReLU, Softmax, and AveragePooling2D.
    if cls._no_trainable_weights(layer):
      return True

    if layer.__class__ in cls._LAYERS_CONFIG_MAP:
      return True

    return False

  @classmethod
  def _weight_names(cls, layer):

    if cls._no_trainable_weights(layer):
      return []

    return cls._LAYERS_CONFIG_MAP[layer.__class__].weight_attrs

  def apply_cluster_preserve_quantize_config(self, layer, quantize_config):
    """Applies cluster-preserve weight quantizer.

    Args:
      layer: The layer to apply the quantize config to.
      quantize_config: The quantization config to extend with support for
        cluster preservation on clustered weights.
    Returns:
      The quantize_config with an add-on cluster-preserving weight_quantizer.
    """
    if not self.supports(layer):
      raise ValueError('Layer ' + str(layer.__class__) + ' is not supported.')

    # Example: ReLU, Softmax, and AveragePooling2D (without trainable weights)
    # DepthwiseConv2D (cluster_preserve is disabled)
    if self._no_trainable_weights(layer) or self._disable_cluster_preserve(
        layer):
      return quantize_config

    # Example: Conv2D, Dense layers
    if quantize_config.__class__.__name__ in self._LAYERS_CONFIG_MAP[
        layer.__class__].quantize_config_attrs:
      quantize_config.weight_quantizer = self._config_quantizer_map[
          quantize_config.__class__.__name__]
    else:
      raise ValueError('Configuration ' +
                       str(quantize_config.__class__.__name__) +
                       ' is not supported for Layer ' + str(layer.__class__) +
                       '.')

    return quantize_config


class Default8bitClusterPreserveQuantizeRegistry(
    ClusterPreserveQuantizeRegistry):
  """Default 8 bit ClusterPreserveQuantizeRegistry."""

  def get_quantize_config(self, layer):
    """Returns the quantization config with weight_quantizer for a given layer.

    Args:
      layer: input layer to return quantize config for.
    Returns:
      Returns the quantization config for cluster preserve weight_quantizer.
    """
    quantize_config = (default_8bit_quantize_registry.
                       Default8BitQuantizeRegistry().
                       get_quantize_config(layer))
    cluster_aware_quantize_config = super(
        Default8bitClusterPreserveQuantizeRegistry,
        self).apply_cluster_preserve_quantize_config(layer, quantize_config)

    return cluster_aware_quantize_config
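
# Illustrative usage (a sketch; assumes `layer` is a supported, already-built
# Keras layer, e.g. a `layers.Dense` instance inside a quantize-annotated
# model):
#
#   registry = Default8bitClusterPreserveQuantizeRegistry(
#       preserve_sparsity=False)
#   if registry.supports(layer):
#     quantize_config = registry.get_quantize_config(layer)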


class ClusterPreserveDefaultWeightsQuantizer(quantizers.LastValueQuantizer):
  """Quantize weights while preserving clusters."""

  def __init__(
      self, num_bits, per_axis, symmetric, narrow_range, preserve_sparsity):
    """ClusterPreserveDefaultWeightsQuantizer.

    Args:
      num_bits: Number of bits for quantization.
      per_axis: Whether to apply per_axis quantization. The last dimension is
        used as the axis.
      symmetric: If true, use symmetric quantization limits instead of training
        the minimum and maximum of each quantization range separately.
      narrow_range: In case of 8 bits, narrow_range nudges the quantized range
        to be [-127, 127] instead of [-128, 127]. This ensures symmetric
        range has 0 as the centre.
      preserve_sparsity: Whether to apply prune-cluster-preserving quantization
        aware training.
    """
    super(ClusterPreserveDefaultWeightsQuantizer, self).__init__(
        num_bits=num_bits,
        per_axis=per_axis,
        symmetric=symmetric,
        narrow_range=narrow_range,
    )
    self.preserve_sparsity = preserve_sparsity

  def _build_clusters(self, name, layer):
    """Extracts the cluster centroids and cluster indices.

    Extracts cluster centroids and cluster indices from the pretrained
    clustered model when the input layer is clustered.

    Args:
      name: Name of weights in layer.
      layer: Quantization wrapped keras layer.
    Returns:
      A dictionary of the initial values of the
      cluster centroids, cluster indices, original weights,
      the pretrained flag for marking the first training
      epoch, and weight name.
    """
    result = {}
    weights = getattr(layer.layer, name)
    if self.preserve_sparsity and not tf.reduce_any(weights == 0):
      self.preserve_sparsity = False
      logging.warning(
          'Input layer does not contain zero weights, so applying CQAT'
          ' instead.')
    centroids_mask = None

    # Detects whether layer is convolutional and is clustered per channel
    data_format = getattr(layer.layer, 'data_format', None)
    centroids, num_centroids, lookup, cluster_per_channel = get_centroids(
        layer, weights, data_format)

    if self.preserve_sparsity:
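      # divide_no_nan(w, w) is 1.0 where a weight is non-zero and 0.0 where
      # it is zero, so sparsity_mask marks the positions to keep, while
      # centroids_mask zeroes out the centroid closest to zero so pruned
      # positions stay pruned during training.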
      sparsity_mask = tf.math.divide_no_nan(weights, weights)
      zero_idx = tf.argmin(tf.abs(centroids), axis=-1)
      centroids_mask = 1.0 - tf.one_hot(zero_idx, num_centroids)
      result = {SPARSITY_MASK: sparsity_mask}

    # Prepare clustering variables for the Keras graph when clusters
    # exist, assuming we do not use number_of_clusters larger than 1024
    if num_centroids > 1024:
      warnings.warn(f'No clustering performed on layer {layer.name}.\n'
                    f'Too many centroids to cluster.')
      return result
    # If not enough clusters, we do not preserve clustering
    elif num_centroids <= 1:
      warnings.warn(f'No clustering performed on layer {layer.name}.\n'
                    f'Perhaps too many clusters requested for this layer?')
      return result
    else:
      clst_centroids_tf = layer.add_weight(
          CLUSTER_CENTROIDS,
          shape=centroids.shape,
          initializer=keras.initializers.Constant(
              value=K.batch_get_value([centroids])[0]
          ),
          dtype=centroids.dtype,
          trainable=True,
      )

      ori_weights_tf = layer.add_weight(
          ORIGINAL_WEIGHTS,
          shape=weights.shape,
          initializer=keras.initializers.Constant(
              value=K.batch_get_value([weights])[0]
          ),
          dtype=weights.dtype,
          trainable=True,
      )

      # Get clustering implementation according to layer type
      clustering_impl_cls = clustering_registry.ClusteringLookupRegistry(
      ).get_clustering_impl(
          layer.layer, name, cluster_per_channel=cluster_per_channel)
      clustering_impl = clustering_impl_cls(
          clst_centroids_tf, cluster_config.GradientAggregation.SUM,
          data_format)

      pulling_indices = tf.dtypes.cast(
          clustering_impl.get_pulling_indices(ori_weights_tf),
          lookup.dtype
      )

      pulling_indices_tf = layer.add_weight(
          PULLING_INDICES,
          shape=lookup.shape,
          initializer=keras.initializers.Constant(
              value=K.batch_get_value([pulling_indices])[0]
          ),
          dtype=lookup.dtype,
          trainable=False,
      )

      result_clst = {
          CLUSTER_CENTROIDS: clst_centroids_tf,
          PULLING_INDICES: pulling_indices_tf,
          ORIGINAL_WEIGHTS: ori_weights_tf,
          WEIGHT_NAME: name,
          CLUSTERING_IMPL: clustering_impl,
          CENTROIDS_MASK: centroids_mask,
      }
      result.update(result_clst)
      return result

  def build(self, tensor_shape, name, layer):
    """Build (P)CQAT wrapper.

    When preserve_sparsity is true and the input is clustered.

    Args:
      tensor_shape: Shape of weights which needs to be quantized.
      name: Name of weights in layer.
      layer: Quantization wrapped keras layer.
    Returns:
      Dictionary of centroids, indices, and quantization params; this
      dictionary is passed to the __call__ function.
    """
    # Get all the initial values from the pretrained clustered model.
    result = self._build_clusters(name, layer)
    # If the result holds clustering nodes, this is CQAT; if it also holds a
    # sparsity mask, this is PCQAT.
    result.update(
        super(ClusterPreserveDefaultWeightsQuantizer,
              self).build(tensor_shape, name, layer))

    return result

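  # The dictionary returned by `build` is what `__call__` receives as its
  # `weights` argument: for CQAT it holds the clustering variables
  # (CLUSTER_CENTROIDS, PULLING_INDICES, ORIGINAL_WEIGHTS, WEIGHT_NAME,
  # CLUSTERING_IMPL, CENTROIDS_MASK) plus the min/max quantization variables
  # from the parent build; for PCQAT it additionally holds SPARSITY_MASK.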
  def __call__(self, inputs, training, weights, **kwargs):
    """Apply cluster preserved quantization to the input tensor.

    Args:
      inputs: Input tensor (layer's weights) to be quantized.
      training: Whether the graph is currently training.
      weights: Dictionary of weights (params) the quantizer can use to
        quantize the tensor (layer's weights). This contains the weights
        created in the `build` function.
      **kwargs: Additional variables which may be passed to the quantizer.
    Returns:
      The quantized tensor.
    """
    if training:
      if CLUSTER_CENTROIDS in weights:
        if self.preserve_sparsity:
          weights[ORIGINAL_WEIGHTS].assign(
              tf.multiply(weights[ORIGINAL_WEIGHTS],
                          weights[SPARSITY_MASK]))
          weights[CLUSTERING_IMPL].cluster_centroids.assign(
              weights[CLUSTERING_IMPL].
              cluster_centroids * weights[CENTROIDS_MASK]
          )
          weights[CLUSTER_CENTROIDS].assign(
              weights[CLUSTERING_IMPL].cluster_centroids
          )
        # Refresh the pulling indices from the latest original weights
        weights[PULLING_INDICES].assign(tf.dtypes.cast(
            weights[CLUSTERING_IMPL].get_pulling_indices(
                weights[ORIGINAL_WEIGHTS]),
            weights[PULLING_INDICES].dtype
        ))

        output = weights[CLUSTERING_IMPL].get_clustered_weight(
            weights[PULLING_INDICES], weights[ORIGINAL_WEIGHTS])
        inputs.assign(output)
      else:
        if self.preserve_sparsity:
          inputs = tf.multiply(inputs, weights[SPARSITY_MASK])
        output = inputs
    else:
      output = inputs

    return quant_ops.LastValueQuantize(
        output,
        weights['min_var'],
        weights['max_var'],
        is_training=training,
        num_bits=self.num_bits,
        per_channel=self.per_axis,
        symmetric=self.symmetric,
        narrow_range=self.narrow_range
    )


class ClusterPreserveDefault8BitWeightsQuantizer(
    ClusterPreserveDefaultWeightsQuantizer):
  """ClusterPreserveWeightsQuantizer for default 8bit weights."""

  def __init__(self, preserve_sparsity):
    super(ClusterPreserveDefault8BitWeightsQuantizer,
          self).__init__(num_bits=8,
                         per_axis=False,
                         symmetric=True,
                         narrow_range=True,
                         preserve_sparsity=preserve_sparsity)
    self.preserve_sparsity = preserve_sparsity
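
# A minimal construction sketch (the registry above wires this quantizer in
# internally; shown here only for illustration):
#
#   quantizer = ClusterPreserveDefault8BitWeightsQuantizer(
#       preserve_sparsity=True)
#   # The quantize wrapper then calls quantizer.build(...) with the weight's
#   # shape, name, and the wrapped layer.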


class ClusterPreserveDefault8BitConvWeightsQuantizer(
    ClusterPreserveDefaultWeightsQuantizer,
    default_8bit_quantizers.Default8BitConvWeightsQuantizer):
  """ClusterPreserveWeightsQuantizer for default 8bit Conv2D weights."""

  def __init__(self, preserve_sparsity):  # pylint: disable=super-init-not-called
    default_8bit_quantizers.Default8BitConvWeightsQuantizer.__init__(self)
    self.preserve_sparsity = preserve_sparsity

  def build(self, tensor_shape, name, layer):
    result = ClusterPreserveDefaultWeightsQuantizer._build_clusters(
        self, name, layer)
    result.update(
        default_8bit_quantizers.Default8BitConvWeightsQuantizer.build(
            self, tensor_shape, name, layer))
    return result