File size: 16,208 Bytes
5672777
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A Classification head layer which is common used with sequence encoders."""

import tensorflow as tf, tf_keras

from official.modeling import tf_utils

from official.nlp.modeling.layers import gaussian_process
from official.nlp.modeling.layers import spectral_normalization


class ClassificationHead(tf_keras.layers.Layer):
  """Pooling head for sentence-level classification tasks.

  When `inner_dim` is truthy, the head slices one token out of the input
  sequence, passes it through a dense "pooler" layer, applies dropout and
  projects to `num_classes` logits. When `inner_dim` is 0 or `None`, no pooler
  layer is created and the input is fed to dropout and the output projection
  as is.
  """

  def __init__(self,
               inner_dim,
               num_classes,
               cls_token_idx=0,
               activation="tanh",
               dropout_rate=0.0,
               initializer="glorot_uniform",
               **kwargs):
    """Initializes the `ClassificationHead`.

    Args:
      inner_dim: The dimensionality of inner projection layer. If 0 or `None`
        then only the output projection layer is created.
      num_classes: Number of output classes.
      cls_token_idx: The index inside the sequence to pool.
      activation: Dense layer activation.
      dropout_rate: Dropout probability.
      initializer: Initializer for dense layer kernels.
      **kwargs: Keyword arguments forwarded to the base Keras layer.
    """
    super().__init__(**kwargs)
    self.dropout_rate = dropout_rate
    self.inner_dim = inner_dim
    self.num_classes = num_classes
    self.activation = tf_utils.get_activation(activation)
    self.initializer = tf_keras.initializers.get(initializer)
    self.cls_token_idx = cls_token_idx

    # The pooler layer (and therefore the `dense` attribute) only exists when
    # an inner dimension is requested.
    if self.inner_dim:
      self.dense = tf_keras.layers.Dense(
          units=self.inner_dim,
          activation=self.activation,
          kernel_initializer=tf_utils.clone_initializer(self.initializer),
          name="pooler_dense")
    self.dropout = tf_keras.layers.Dropout(rate=self.dropout_rate)

    self.out_proj = tf_keras.layers.Dense(
        units=num_classes,
        kernel_initializer=tf_utils.clone_initializer(self.initializer),
        name="logits")

  def call(self, features: tf.Tensor, only_project: bool = False):
    """Implements call().

    Args:
      features: a rank-3 Tensor when self.inner_dim is specified, otherwise
        it is a rank-2 Tensor.
      only_project: a boolean. If True, we return the intermediate Tensor
        before projecting to class logits.

    Returns:
      a Tensor, if only_project is True, shape= [batch size, hidden size].
      If only_project is False, shape= [batch size, num classes].
    """
    if not self.inner_dim:
      x = features
    else:
      x = features[:, self.cls_token_idx, :]  # take <CLS> token.
      x = self.dense(x)

    if only_project:
      return x
    x = self.dropout(x)
    x = self.out_proj(x)
    return x

  def get_config(self):
    """Returns the constructor arguments merged with the base layer config."""
    config = {
        "cls_token_idx": self.cls_token_idx,
        "dropout_rate": self.dropout_rate,
        "num_classes": self.num_classes,
        "inner_dim": self.inner_dim,
        "activation": tf_keras.activations.serialize(self.activation),
        "initializer": tf_keras.initializers.serialize(self.initializer),
    }
    config.update(super().get_config())
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Rebuilds the layer from a dict produced by `get_config`."""
    return cls(**config)

  @property
  def checkpoint_items(self):
    """Maps the pooler layer name to the pooler layer, if one was created.

    Returns an empty dict when `inner_dim` is 0 or `None`, because in that case
    `__init__` never creates `self.dense`.
    """
    if not self.inner_dim:
      return {}
    return {self.dense.name: self.dense}


class MultiClsHeads(tf_keras.layers.Layer):
  """Pooling heads sharing the same pooling stem.

  One optional dense "pooler" layer and one dropout layer are shared by all
  heads; each entry of `cls_list` gets its own output projection named after
  the classification problem.
  """

  def __init__(self,
               inner_dim,
               cls_list,
               cls_token_idx=0,
               activation="tanh",
               dropout_rate=0.0,
               initializer="glorot_uniform",
               **kwargs):
    """Initializes the `MultiClsHeads`.

    Args:
      inner_dim: The dimensionality of inner projection layer. If 0 or `None`
        then only the output projection layer is created.
      cls_list: a list of pairs of (classification problem name, number of
        classes).
      cls_token_idx: The index inside the sequence to pool.
      activation: Dense layer activation.
      dropout_rate: Dropout probability.
      initializer: Initializer for dense layer kernels.
      **kwargs: Keyword arguments forwarded to the base Keras layer.
    """
    super().__init__(**kwargs)
    self.dropout_rate = dropout_rate
    self.inner_dim = inner_dim
    self.cls_list = cls_list
    self.activation = tf_utils.get_activation(activation)
    self.initializer = tf_keras.initializers.get(initializer)
    self.cls_token_idx = cls_token_idx

    # The shared pooler (and therefore the `dense` attribute) only exists when
    # an inner dimension is requested.
    if self.inner_dim:
      self.dense = tf_keras.layers.Dense(
          units=inner_dim,
          activation=self.activation,
          kernel_initializer=tf_utils.clone_initializer(self.initializer),
          name="pooler_dense")
    self.dropout = tf_keras.layers.Dropout(rate=self.dropout_rate)
    # One output projection per classification problem, named after it.
    self.out_projs = [
        tf_keras.layers.Dense(
            units=num_classes,
            kernel_initializer=tf_utils.clone_initializer(self.initializer),
            name=name) for name, num_classes in cls_list
    ]

  def call(self, features: tf.Tensor, only_project: bool = False):
    """Implements call().

    Args:
      features: a rank-3 Tensor when self.inner_dim is specified, otherwise
        it is a rank-2 Tensor.
      only_project: a boolean. If True, we return the intermediate Tensor
        before projecting to class logits.

    Returns:
      If only_project is True, a Tensor with shape= [batch size, hidden size].
      If only_project is False, a dictionary of Tensors keyed by the name of
      each output projection layer.
    """
    if not self.inner_dim:
      x = features
    else:
      x = features[:, self.cls_token_idx, :]  # take <CLS> token.
      x = self.dense(x)

    if only_project:
      return x
    x = self.dropout(x)

    return {proj_layer.name: proj_layer(x) for proj_layer in self.out_projs}

  def get_config(self):
    """Returns the constructor arguments merged with the base layer config."""
    config = {
        "dropout_rate": self.dropout_rate,
        "cls_token_idx": self.cls_token_idx,
        "cls_list": self.cls_list,
        "inner_dim": self.inner_dim,
        "activation": tf_keras.activations.serialize(self.activation),
        "initializer": tf_keras.initializers.serialize(self.initializer),
    }
    config.update(super().get_config())
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Rebuilds the layer from a dict produced by `get_config`."""
    return cls(**config)

  @property
  def checkpoint_items(self):
    """Maps layer names to the pooler (if any) and every output projection.

    The pooler entry is skipped when `inner_dim` is 0 or `None`, because in
    that case `__init__` never creates `self.dense`.
    """
    items = {}
    if self.inner_dim:
      items[self.dense.name] = self.dense
    items.update({v.name: v for v in self.out_projs})
    return items


class GaussianProcessClassificationHead(ClassificationHead):
  """Gaussian process-based pooling head for sentence classification.

  This class implements a classifier head for BERT encoder that is based on the
  spectral-normalized neural Gaussian process (SNGP) [1]. SNGP is a simple
  method to improve a neural network's uncertainty quantification ability
  without sacrificing accuracy or latency. It applies spectral normalization to
  the hidden pooler layer, and then replaces the dense output layer with a
  Gaussian process.


  [1]: Jeremiah Liu et al. Simple and Principled Uncertainty Estimation with
       Deterministic Deep Learning via Distance Awareness.
       In _Neural Information Processing Systems_, 2020.
       https://arxiv.org/abs/2006.10108
  """

  def __init__(self,
               inner_dim,
               num_classes,
               cls_token_idx=0,
               activation="tanh",
               dropout_rate=0.0,
               initializer="glorot_uniform",
               use_spec_norm=True,
               use_gp_layer=True,
               temperature=None,
               **kwargs):
    """Initializes the `GaussianProcessClassificationHead`.

    Args:
      inner_dim: The dimensionality of inner projection layer. If 0 or `None`
        then only the output projection layer is created.
      num_classes: Number of output classes.
      cls_token_idx: The index inside the sequence to pool.
      activation: Dense layer activation.
      dropout_rate: Dropout probability.
      initializer: Initializer for dense layer kernels.
      use_spec_norm: Whether to apply spectral normalization to pooler layer.
      use_gp_layer: Whether to use Gaussian process as the output layer.
      temperature: The temperature parameter to be used for mean-field
        approximation during inference. If None then no mean-field adjustment is
        applied.
      **kwargs: Additional keyword arguments. Spectral normalization and
        Gaussian process options are popped from here; whatever remains is
        forwarded to `ClassificationHead`.
    """
    self.temperature = temperature
    self.use_gp_layer = use_gp_layer
    self.use_spec_norm = use_spec_norm
    # These helpers pop their keys out of `kwargs`, so they have to run before
    # the remainder is handed to the parent constructor.
    self.spec_norm_kwargs = extract_spec_norm_kwargs(kwargs)
    self.gp_layer_kwargs = extract_gp_layer_kwargs(kwargs)

    super().__init__(
        inner_dim=inner_dim,
        num_classes=num_classes,
        cls_token_idx=cls_token_idx,
        activation=activation,
        dropout_rate=dropout_rate,
        initializer=initializer,
        **kwargs)

    # Wrap the pooler (only present when `inner_dim` is set) with spectral
    # normalization.
    if self.use_spec_norm and hasattr(self, "dense"):
      self.dense = spectral_normalization.SpectralNormalization(
          self.dense, inhere_layer_name=True, **self.spec_norm_kwargs)

    # Swap the parent's Dense output layer for a Gaussian process layer.
    if self.use_gp_layer:
      gp_initializer = tf_utils.clone_initializer(self.initializer)
      self.out_proj = gaussian_process.RandomFeatureGaussianProcess(
          self.num_classes,
          kernel_initializer=gp_initializer,
          name="logits",
          **self.gp_layer_kwargs)

  def call(self, features, training=False, return_covmat=False):
    """Returns model output.

    During training, the model returns raw logits. During evaluation, the model
    returns uncertainty adjusted logits, and (optionally) the covariance matrix.

    Arguments:
      features: A tensor of input features, shape (batch_size, feature_dim).
      training: Whether the model is in training mode.
      return_covmat: Whether the model should also return covariance matrix if
        `use_gp_layer=True`. During training, it is recommended to set
        `return_covmat=False` to be compatible with the standard Keras pipelines
        (e.g., `model.fit()`).

    Returns:
      logits: Uncertainty-adjusted predictive logits, shape
        (batch_size, num_classes).
      covmat: (Optional) Covariance matrix, shape (batch_size, batch_size).
        Returned only when return_covmat=True.
    """
    head_outputs = super().call(features)

    # With the Gaussian process layer the parent output is unpacked as a
    # (logits, covariance matrix) pair; otherwise it is the logits alone.
    if self.use_gp_layer:
      logits, covmat = head_outputs
    else:
      logits, covmat = head_outputs, None

    # The mean-field adjustment is applied outside of training only.
    if not training:
      logits = gaussian_process.mean_field_logits(
          logits, covmat, mean_field_factor=self.temperature)

    if not return_covmat or covmat is None:
      return logits
    return logits, covmat

  def reset_covariance_matrix(self):
    """Resets covariance matrix of the Gaussian process layer."""
    # A plain Dense output layer has nothing to reset.
    if not hasattr(self.out_proj, "reset_covariance_matrix"):
      return
    self.out_proj.reset_covariance_matrix()

  def get_config(self):
    """Returns the GP-specific options merged with the parent layer config."""
    config = {
        "use_spec_norm": self.use_spec_norm,
        "use_gp_layer": self.use_gp_layer,
        **self.spec_norm_kwargs,
        **self.gp_layer_kwargs,
        "temperature": self.temperature,
    }
    config.update(super().get_config())
    return config


def extract_gp_layer_kwargs(kwargs):
  """Pops the Gaussian process layer options out of `kwargs`.

  Args:
    kwargs: A mutable dict of keyword arguments. Every recognized option is
      removed from it in place.

  Returns:
    A dict holding every Gaussian process option, using the value found in
    `kwargs` when present and the default listed below otherwise.
  """
  # (option name, default value) pairs, in the order they are extracted.
  gp_option_defaults = (
      ("num_inducing", 1024),
      ("normalize_input", True),
      ("gp_cov_momentum", 0.999),
      ("gp_cov_ridge_penalty", 1.),
      ("scale_random_features", False),
      ("l2_regularization", 1e-6),
      ("gp_cov_likelihood", "gaussian"),
      ("return_gp_cov", True),
      ("return_random_features", False),
      ("use_custom_random_features", True),
      ("custom_random_features_initializer", "random_normal"),
      ("custom_random_features_activation", None),
  )
  return {
      option: kwargs.pop(option, default)
      for option, default in gp_option_defaults
  }


def extract_spec_norm_kwargs(kwargs):
  """Pops the spectral normalization options out of `kwargs`.

  Args:
    kwargs: A mutable dict of keyword arguments. The `iteration` and
      `norm_multiplier` entries are removed from it in place when present.

  Returns:
    A dict with `iteration` (default 1) and `norm_multiplier` (default 0.99).
  """
  spec_norm_defaults = (("iteration", 1), ("norm_multiplier", .99))
  return {
      option: kwargs.pop(option, default)
      for option, default in spec_norm_defaults
  }


class PerQueryDenseHead(tf_keras.layers.Layer):
  """Pooling head used for EncT5 style models.

    This module projects each query to use a different projection.

    For an input shape= [bs, num_queries, hidden_size], it projects each query
    to (features). Ending up with shape= [bs, num_queries, features].

    For example, for classification with a few classes, one may use num_queries
    as 1 and features as number of classes. For multilabel classification, one
    may use num_queries as number of classes and features as 2. So each query
    represents a binary classification of one label.
  """

  def __init__(self,
               num_queries: int,
               features: int,
               use_bias: bool = False,
               kernel_initializer: str = "glorot_uniform",
               **kwargs):
    """Initializes the `PerQueryDenseHead`.

    Args:
      num_queries: number of queries (the learnable embeddings in the input
        sequences) from the decoder.
      features: int with numbers of output features. Each query will be
        projected to this number with a different projection.
      use_bias: whether to add a bias to the output.
      kernel_initializer: Initializer for dense layer kernels.
      **kwargs: Keyword arguments forwarded to the base Keras layer.
    """
    super().__init__(**kwargs)
    self.num_queries = num_queries
    self.features = features

    self.use_bias = use_bias
    self.kernel_initializer = tf_keras.initializers.get(kernel_initializer)

  def build(self, input_shape):
    """Creates the per-query kernel and, if enabled, the per-query bias."""
    input_shape = tf.TensorShape(input_shape)
    # Hidden size.
    last_dim = tf.compat.dimension_value(input_shape[-1])

    self.hidden_size = last_dim
    # One [hidden_size, features] projection matrix per query.
    self.kernel = self.add_weight(
        "kernel",
        shape=[self.num_queries, last_dim, self.features],
        initializer=self.kernel_initializer,
        dtype=self.dtype,
        trainable=True)
    if self.use_bias:
      # No explicit initializer is passed here, so `add_weight` falls back to
      # its own default.
      self.bias = self.add_weight(
          "bias",
          shape=[
              self.num_queries,
              self.features,
          ],
          dtype=self.dtype,
          trainable=True)
    else:
      self.bias = None

  def call(self, inputs: tf.Tensor) -> tf.Tensor:
    """Implements call().

    Args:
      inputs: a rank-3 Tensor of shape= [bs, num_queries, hidden_size].

    Returns:
      A Tensor, shape= [batch size, num_queries, features].
    """

    # Contract the hidden axis separately for every query index `q`.
    outputs = tf.einsum("bqh,qhf->bqf", inputs, self.kernel)
    if self.use_bias:
      outputs += self.bias
    return outputs

  def get_config(self):
    """Returns the constructor arguments merged with the base layer config."""
    config = {
        "num_queries":
            self.num_queries,
        "features":
            self.features,
        # `use_bias` must be saved, otherwise `from_config` rebuilds the layer
        # with the default `use_bias=False`.
        "use_bias":
            self.use_bias,
        "kernel_initializer":
            tf_keras.initializers.serialize(self.kernel_initializer),
    }
    config.update(super().get_config())
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Rebuilds the layer from a dict produced by `get_config`."""
    return cls(**config)