# Copyright 2023 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """A Classification head layer which is common used with sequence encoders.""" import tensorflow as tf, tf_keras from official.modeling import tf_utils from official.nlp.modeling.layers import gaussian_process from official.nlp.modeling.layers import spectral_normalization class ClassificationHead(tf_keras.layers.Layer): """Pooling head for sentence-level classification tasks.""" def __init__(self, inner_dim, num_classes, cls_token_idx=0, activation="tanh", dropout_rate=0.0, initializer="glorot_uniform", **kwargs): """Initializes the `ClassificationHead`. Args: inner_dim: The dimensionality of inner projection layer. If 0 or `None` then only the output projection layer is created. num_classes: Number of output classes. cls_token_idx: The index inside the sequence to pool. activation: Dense layer activation. dropout_rate: Dropout probability. initializer: Initializer for dense layer kernels. **kwargs: Keyword arguments. """ super().__init__(**kwargs) self.dropout_rate = dropout_rate self.inner_dim = inner_dim self.num_classes = num_classes self.activation = tf_utils.get_activation(activation) self.initializer = tf_keras.initializers.get(initializer) self.cls_token_idx = cls_token_idx if self.inner_dim: self.dense = tf_keras.layers.Dense( units=self.inner_dim, activation=self.activation, kernel_initializer=tf_utils.clone_initializer(self.initializer), name="pooler_dense") self.dropout = tf_keras.layers.Dropout(rate=self.dropout_rate) self.out_proj = tf_keras.layers.Dense( units=num_classes, kernel_initializer=tf_utils.clone_initializer(self.initializer), name="logits") def call(self, features: tf.Tensor, only_project: bool = False): """Implements call(). Args: features: a rank-3 Tensor when self.inner_dim is specified, otherwise it is a rank-2 Tensor. only_project: a boolean. If True, we return the intermediate Tensor before projecting to class logits. Returns: a Tensor, if only_project is True, shape= [batch size, hidden size]. If only_project is False, shape= [batch size, num classes]. """ if not self.inner_dim: x = features else: x = features[:, self.cls_token_idx, :] # take token. x = self.dense(x) if only_project: return x x = self.dropout(x) x = self.out_proj(x) return x def get_config(self): config = { "cls_token_idx": self.cls_token_idx, "dropout_rate": self.dropout_rate, "num_classes": self.num_classes, "inner_dim": self.inner_dim, "activation": tf_keras.activations.serialize(self.activation), "initializer": tf_keras.initializers.serialize(self.initializer), } config.update(super(ClassificationHead, self).get_config()) return config @classmethod def from_config(cls, config, custom_objects=None): return cls(**config) @property def checkpoint_items(self): return {self.dense.name: self.dense} class MultiClsHeads(tf_keras.layers.Layer): """Pooling heads sharing the same pooling stem.""" def __init__(self, inner_dim, cls_list, cls_token_idx=0, activation="tanh", dropout_rate=0.0, initializer="glorot_uniform", **kwargs): """Initializes the `MultiClsHeads`. Args: inner_dim: The dimensionality of inner projection layer. If 0 or `None` then only the output projection layer is created. cls_list: a list of pairs of (classification problem name and the numbers of classes. cls_token_idx: The index inside the sequence to pool. activation: Dense layer activation. dropout_rate: Dropout probability. initializer: Initializer for dense layer kernels. **kwargs: Keyword arguments. """ super().__init__(**kwargs) self.dropout_rate = dropout_rate self.inner_dim = inner_dim self.cls_list = cls_list self.activation = tf_utils.get_activation(activation) self.initializer = tf_keras.initializers.get(initializer) self.cls_token_idx = cls_token_idx if self.inner_dim: self.dense = tf_keras.layers.Dense( units=inner_dim, activation=self.activation, kernel_initializer=tf_utils.clone_initializer(self.initializer), name="pooler_dense") self.dropout = tf_keras.layers.Dropout(rate=self.dropout_rate) self.out_projs = [] for name, num_classes in cls_list: self.out_projs.append( tf_keras.layers.Dense( units=num_classes, kernel_initializer=tf_utils.clone_initializer(self.initializer), name=name)) def call(self, features: tf.Tensor, only_project: bool = False): """Implements call(). Args: features: a rank-3 Tensor when self.inner_dim is specified, otherwise it is a rank-2 Tensor. only_project: a boolean. If True, we return the intermediate Tensor before projecting to class logits. Returns: If only_project is True, a Tensor with shape= [batch size, hidden size]. If only_project is False, a dictionary of Tensors. """ if not self.inner_dim: x = features else: x = features[:, self.cls_token_idx, :] # take token. x = self.dense(x) if only_project: return x x = self.dropout(x) outputs = {} for proj_layer in self.out_projs: outputs[proj_layer.name] = proj_layer(x) return outputs def get_config(self): config = { "dropout_rate": self.dropout_rate, "cls_token_idx": self.cls_token_idx, "cls_list": self.cls_list, "inner_dim": self.inner_dim, "activation": tf_keras.activations.serialize(self.activation), "initializer": tf_keras.initializers.serialize(self.initializer), } config.update(super().get_config()) return config @classmethod def from_config(cls, config, custom_objects=None): return cls(**config) @property def checkpoint_items(self): items = {self.dense.name: self.dense} items.update({v.name: v for v in self.out_projs}) return items class GaussianProcessClassificationHead(ClassificationHead): """Gaussian process-based pooling head for sentence classification. This class implements a classifier head for BERT encoder that is based on the spectral-normalized neural Gaussian process (SNGP) [1]. SNGP is a simple method to improve a neural network's uncertainty quantification ability without sacrificing accuracy or lantency. It applies spectral normalization to the hidden pooler layer, and then replaces the dense output layer with a Gaussian process. [1]: Jeremiah Liu et al. Simple and Principled Uncertainty Estimation with Deterministic Deep Learning via Distance Awareness. In _Neural Information Processing Systems_, 2020. https://arxiv.org/abs/2006.10108 """ def __init__(self, inner_dim, num_classes, cls_token_idx=0, activation="tanh", dropout_rate=0.0, initializer="glorot_uniform", use_spec_norm=True, use_gp_layer=True, temperature=None, **kwargs): """Initializes the `GaussianProcessClassificationHead`. Args: inner_dim: The dimensionality of inner projection layer. If 0 or `None` then only the output projection layer is created. num_classes: Number of output classes. cls_token_idx: The index inside the sequence to pool. activation: Dense layer activation. dropout_rate: Dropout probability. initializer: Initializer for dense layer kernels. use_spec_norm: Whether to apply spectral normalization to pooler layer. use_gp_layer: Whether to use Gaussian process as the output layer. temperature: The temperature parameter to be used for mean-field approximation during inference. If None then no mean-field adjustment is applied. **kwargs: Additional keyword arguments. """ # Collects spectral normalization and Gaussian process args from kwargs. self.use_spec_norm = use_spec_norm self.use_gp_layer = use_gp_layer self.spec_norm_kwargs = extract_spec_norm_kwargs(kwargs) self.gp_layer_kwargs = extract_gp_layer_kwargs(kwargs) self.temperature = temperature super().__init__( inner_dim=inner_dim, num_classes=num_classes, cls_token_idx=cls_token_idx, activation=activation, dropout_rate=dropout_rate, initializer=initializer, **kwargs) # Applies spectral normalization to the dense pooler layer. if self.use_spec_norm and hasattr(self, "dense"): self.dense = spectral_normalization.SpectralNormalization( self.dense, inhere_layer_name=True, **self.spec_norm_kwargs) # Replace Dense output layer with the Gaussian process layer. if use_gp_layer: self.out_proj = gaussian_process.RandomFeatureGaussianProcess( self.num_classes, kernel_initializer=tf_utils.clone_initializer(self.initializer), name="logits", **self.gp_layer_kwargs) def call(self, features, training=False, return_covmat=False): """Returns model output. Dring training, the model returns raw logits. During evaluation, the model returns uncertainty adjusted logits, and (optionally) the covariance matrix. Arguments: features: A tensor of input features, shape (batch_size, feature_dim). training: Whether the model is in training mode. return_covmat: Whether the model should also return covariance matrix if `use_gp_layer=True`. During training, it is recommended to set `return_covmat=False` to be compatible with the standard Keras pipelines (e.g., `model.fit()`). Returns: logits: Uncertainty-adjusted predictive logits, shape (batch_size, num_classes). covmat: (Optional) Covariance matrix, shape (batch_size, batch_size). Returned only when return_covmat=True. """ logits = super().call(features) # Extracts logits and covariance matrix from model output. if self.use_gp_layer: logits, covmat = logits else: covmat = None # Computes the uncertainty-adjusted logits during evaluation. if not training: logits = gaussian_process.mean_field_logits( logits, covmat, mean_field_factor=self.temperature) if return_covmat and covmat is not None: return logits, covmat return logits def reset_covariance_matrix(self): """Resets covariance matrix of the Gaussian process layer.""" if hasattr(self.out_proj, "reset_covariance_matrix"): self.out_proj.reset_covariance_matrix() def get_config(self): config = dict( use_spec_norm=self.use_spec_norm, use_gp_layer=self.use_gp_layer) config.update(self.spec_norm_kwargs) config.update(self.gp_layer_kwargs) config["temperature"] = self.temperature config.update(super(GaussianProcessClassificationHead, self).get_config()) return config def extract_gp_layer_kwargs(kwargs): """Extracts Gaussian process layer configs from a given kwarg.""" return dict( num_inducing=kwargs.pop("num_inducing", 1024), normalize_input=kwargs.pop("normalize_input", True), gp_cov_momentum=kwargs.pop("gp_cov_momentum", 0.999), gp_cov_ridge_penalty=kwargs.pop("gp_cov_ridge_penalty", 1.), scale_random_features=kwargs.pop("scale_random_features", False), l2_regularization=kwargs.pop("l2_regularization", 1e-6), gp_cov_likelihood=kwargs.pop("gp_cov_likelihood", "gaussian"), return_gp_cov=kwargs.pop("return_gp_cov", True), return_random_features=kwargs.pop("return_random_features", False), use_custom_random_features=kwargs.pop("use_custom_random_features", True), custom_random_features_initializer=kwargs.pop( "custom_random_features_initializer", "random_normal"), custom_random_features_activation=kwargs.pop( "custom_random_features_activation", None)) def extract_spec_norm_kwargs(kwargs): """Extracts spectral normalization configs from a given kwarg.""" return dict( iteration=kwargs.pop("iteration", 1), norm_multiplier=kwargs.pop("norm_multiplier", .99)) class PerQueryDenseHead(tf_keras.layers.Layer): """Pooling head used for EncT5 style models. This module projects each query to use a different projection. For a input shape= [bs, num_queries, hidden_size], it projects each query to (features). Ending up with shape= [bs, num_queries, features]. For example, for classification with a few classes, one may use num_queries as 1 and features as number of classes. For multilabel classification, one may use num_queries as number of classes and features as 2. So each query represents a binary classification of one label. """ def __init__(self, num_queries: int, features: int, use_bias: bool = False, kernel_initializer: str = "glorot_uniform", **kwargs): """Initializes the `PerQueryDenseHead`. Args: num_queries: number of queries (the learnable embeddings in the input sequences) from the decoder. features: int with numbers of output features. Each query with be projected to this number with a different projection. use_bias: whether to add a bias to the output. kernel_initializer: Initializer for dense layer kernels. **kwargs: Keyword arguments. """ super().__init__(**kwargs) self.num_queries = num_queries self.features = features self.use_bias = use_bias self.kernel_initializer = tf_keras.initializers.get(kernel_initializer) def build(self, input_shape): input_shape = tf.TensorShape(input_shape) # Hidden size. last_dim = tf.compat.dimension_value(input_shape[-1]) self.hidden_size = last_dim self.kernel = self.add_weight( "kernel", shape=[self.num_queries, last_dim, self.features], initializer=self.kernel_initializer, dtype=self.dtype, trainable=True) if self.use_bias: self.bias = self.add_weight( "bias", shape=[ self.num_queries, self.features, ], dtype=self.dtype, trainable=True) else: self.bias = None def call(self, inputs: tf.Tensor) -> tf.Tensor: """Implements call(). Args: inputs: a rank-3 Tensor of shape= [bs, num_queries, hidden_size]. Returns: A Tensor, shape= [batch size, num_queries, features]. """ outputs = tf.einsum("bqh,qhf->bqf", inputs, self.kernel) if self.use_bias: outputs += self.bias return outputs def get_config(self): config = { "num_queries": self.num_queries, "features": self.features, "kernel_initializer": tf_keras.activations.serialize(self.kernel_initializer), } config.update(super(PerQueryDenseHead, self).get_config()) return config @classmethod def from_config(cls, config, custom_objects=None): return cls(**config)