# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""XLNet models."""
# pylint: disable=g-classes-have-attributes

from typing import Any, Mapping, Optional, Union

import tensorflow as tf
import tf_keras

from official.nlp.modeling import layers
from official.nlp.modeling import networks


class XLNetMaskedLM(tf_keras.layers.Layer):
  """XLNet pretraining head.

  Args:
    vocab_size: The size of the output vocabulary.
    hidden_size: The hidden size of the dense transform layer.
    initializer: The initializer for the dense transform layer.
    activation: The activation for the dense transform layer.
  """

  def __init__(self,
               vocab_size: int,
               hidden_size: int,
               initializer: str = 'glorot_uniform',
               activation: str = 'gelu',
               name=None,
               **kwargs):
    super().__init__(name=name, **kwargs)
    self._vocab_size = vocab_size
    self._hidden_size = hidden_size
    self._initializer = initializer
    self._activation = activation

  def build(self, input_shape):
    self.dense = tf_keras.layers.Dense(
        units=self._hidden_size,
        activation=self._activation,
        kernel_initializer=self._initializer,
        name='transform/dense')
    self.layer_norm = tf_keras.layers.LayerNormalization(
        axis=-1, epsilon=1e-12, name='transform/LayerNorm')
    self.bias = self.add_weight(
        'output_bias/bias',
        shape=(self._vocab_size,),
        initializer='zeros',
        trainable=True)
    super().build(input_shape)

  def call(self, sequence_data: tf.Tensor, embedding_table: tf.Tensor):
    # Transform the encoder output, then project it back onto the vocabulary
    # by reusing (the transpose of) the input embedding table.
    lm_data = self.dense(sequence_data)
    lm_data = self.layer_norm(lm_data)
    lm_data = tf.matmul(lm_data, embedding_table, transpose_b=True)
    logits = tf.nn.bias_add(lm_data, self.bias)
    return logits

  def get_config(self) -> Mapping[str, Any]:
    config = {
        'vocab_size': self._vocab_size,
        'hidden_size': self._hidden_size,
        'initializer': self._initializer,
        'activation': self._activation,
    }
    base_config = super().get_config()
    return dict(list(base_config.items()) + list(config.items()))


@tf_keras.utils.register_keras_serializable(package='Text')
class XLNetPretrainer(tf_keras.Model):
  """XLNet-based pretrainer.

  This is an implementation of the network structure surrounding a
  Transformer-XL encoder as described in "XLNet: Generalized Autoregressive
  Pretraining for Language Understanding" (https://arxiv.org/abs/1906.08237).

  Args:
    network: An XLNet/Transformer-XL based network. This network should
      output a sequence output and a list of `state` tensors.
    mlm_activation: The activation (if any) to use in the Masked LM network.
      If None, then no activation will be used.
    mlm_initializer: The initializer (if any) to use in the masked LM.
      Defaults to a Glorot uniform initializer.
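
  Example:
    A minimal forward-pass sketch. The encoder construction is elided, and
    the input tensor names and shapes are illustrative assumptions following
    the conventions of the XLNet paper, not canonical values:

    ```python
    encoder = networks.XLNetBase(...)  # Constructor arguments elided.
    pretrainer = XLNetPretrainer(network=encoder, mlm_activation='gelu')
    mlm_logits, state = pretrainer({
        'input_word_ids': input_word_ids,      # [batch, seq_len]
        'input_type_ids': input_type_ids,      # [batch, seq_len]
        'masked_tokens': masked_tokens,        # [batch, seq_len]
        'permutation_mask': permutation_mask,  # [batch, seq_len, seq_len]
        'target_mapping': target_mapping,      # [batch, num_predict, seq_len]
    })
    # `mlm_logits` is expected to have shape [batch, num_predict, vocab_size].
    ```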
""" def __init__( self, network: Union[tf_keras.layers.Layer, tf_keras.Model], mlm_activation=None, mlm_initializer='glorot_uniform', name: Optional[str] = None, **kwargs): super().__init__(name=name, **kwargs) self._config = { 'network': network, 'mlm_activation': mlm_activation, 'mlm_initializer': mlm_initializer, } self._network = network self._hidden_size = network.get_config()['hidden_size'] self._vocab_size = network.get_config()['vocab_size'] self._activation = mlm_activation self._initializer = mlm_initializer self._masked_lm = XLNetMaskedLM( vocab_size=self._vocab_size, hidden_size=self._hidden_size, initializer=self._initializer) def call(self, inputs: Mapping[str, Any]): # pytype: disable=signature-mismatch # overriding-parameter-count-checks input_word_ids = inputs['input_word_ids'] input_type_ids = inputs['input_type_ids'] masked_tokens = inputs['masked_tokens'] permutation_mask = inputs['permutation_mask'] target_mapping = inputs['target_mapping'] state = inputs.get('state', None) attention_output, state = self._network( input_ids=input_word_ids, segment_ids=input_type_ids, input_mask=None, state=state, permutation_mask=permutation_mask, target_mapping=target_mapping, masked_tokens=masked_tokens) embedding_table = self._network.get_embedding_lookup_table() mlm_outputs = self._masked_lm( sequence_data=attention_output, embedding_table=embedding_table) return mlm_outputs, state def get_config(self) -> Mapping[str, Any]: return self._config @classmethod def from_config(cls, config, custom_objects=None): return cls(**config) @property def checkpoint_items(self): return dict(encoder=self._network) @tf_keras.utils.register_keras_serializable(package='Text') class XLNetClassifier(tf_keras.Model): """Classifier model based on XLNet. This is an implementation of the network structure surrounding a Transformer-XL encoder as described in "XLNet: Generalized Autoregressive Pretraining for Language Understanding" (https://arxiv.org/abs/1906.08237). Note: This model does not use utilize the memory mechanism used in the original XLNet Classifier. Args: network: An XLNet/Transformer-XL based network. This network should output a sequence output and list of `state` tensors. num_classes: Number of classes to predict from the classification network. initializer: The initializer (if any) to use in the classification networks. Defaults to a RandomNormal initializer. summary_type: Method used to summarize a sequence into a compact vector. dropout_rate: The dropout probability of the cls head. head_name: Name of the classification head. """ def __init__( self, network: Union[tf_keras.layers.Layer, tf_keras.Model], num_classes: int, initializer: tf_keras.initializers.Initializer = 'random_normal', summary_type: str = 'last', dropout_rate: float = 0.1, head_name: str = 'sentence_prediction', # pytype: disable=annotation-type-mismatch # typed-keras **kwargs): super().__init__(**kwargs) self._network = network self._initializer = initializer self._summary_type = summary_type self._num_classes = num_classes self._config = { 'network': network, 'initializer': initializer, 'num_classes': num_classes, 'summary_type': summary_type, 'dropout_rate': dropout_rate, 'head_name': head_name, } if summary_type == 'last': cls_token_idx = -1 elif summary_type == 'first': cls_token_idx = 0 else: raise ValueError('Invalid summary type provided: %s.' 
  """

  def __init__(
      self,
      network: Union[tf_keras.layers.Layer, tf_keras.Model],
      num_classes: int,
      initializer: Union[
          str, tf_keras.initializers.Initializer] = 'random_normal',
      summary_type: str = 'last',
      dropout_rate: float = 0.1,
      head_name: str = 'sentence_prediction',
      **kwargs):
    super().__init__(**kwargs)
    self._network = network
    self._initializer = initializer
    self._summary_type = summary_type
    self._num_classes = num_classes
    self._config = {
        'network': network,
        'initializer': initializer,
        'num_classes': num_classes,
        'summary_type': summary_type,
        'dropout_rate': dropout_rate,
        'head_name': head_name,
    }
    if summary_type == 'last':
      cls_token_idx = -1
    elif summary_type == 'first':
      cls_token_idx = 0
    else:
      raise ValueError('Invalid summary type provided: %s.' % summary_type)

    self.classifier = layers.ClassificationHead(
        inner_dim=network.get_config()['hidden_size'],
        num_classes=num_classes,
        initializer=initializer,
        dropout_rate=dropout_rate,
        cls_token_idx=cls_token_idx,
        name=head_name)

  def call(self, inputs: Mapping[str, Any]):  # pytype: disable=signature-mismatch  # overriding-parameter-count-checks
    input_ids = inputs['input_word_ids']
    segment_ids = inputs['input_type_ids']
    input_mask = tf.cast(inputs['input_mask'], tf.float32)
    state = inputs.get('mems', None)

    attention_output, _ = self._network(
        input_ids=input_ids,
        segment_ids=segment_ids,
        input_mask=input_mask,
        state=state)

    logits = self.classifier(attention_output)
    return logits

  def get_config(self):
    return self._config

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)

  @property
  def checkpoint_items(self):
    items = dict(encoder=self._network)
    if hasattr(self.classifier, 'checkpoint_items'):
      for key, item in self.classifier.checkpoint_items.items():
        items['.'.join([self.classifier.name, key])] = item
    return items


@tf_keras.utils.register_keras_serializable(package='Text')
class XLNetSpanLabeler(tf_keras.Model):
  """Span labeler model based on XLNet.

  This is an implementation of the network structure surrounding a
  Transformer-XL encoder as described in "XLNet: Generalized Autoregressive
  Pretraining for Language Understanding" (https://arxiv.org/abs/1906.08237).

  Args:
    network: A transformer network producing a sequence output. Both
      XLNet-style encoders (which return `(output, states)` tuples) and
      BERT-style encoders (which return an output dictionary containing a
      `sequence_output` key) are supported.
    start_n_top: Beam size for span start.
    end_n_top: Beam size for span end.
    dropout_rate: The dropout rate for the span labeling layer.
    span_labeling_activation: The activation for the span labeling head.
    initializer: The initializer (if any) to use in the span labeling
      network. Defaults to a Glorot uniform initializer.
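
  Example:
    A minimal usage sketch. The encoder construction is elided, the shapes in
    the comments are illustrative assumptions, and the exact keys of
    `outputs` are determined by `networks.XLNetSpanLabeling`:

    ```python
    encoder = networks.XLNetBase(...)  # Constructor arguments elided.
    span_labeler = XLNetSpanLabeler(network=encoder)
    outputs = span_labeler({
        'input_word_ids': input_word_ids,    # [batch, seq_len]
        'input_type_ids': input_type_ids,    # [batch, seq_len]
        'input_mask': input_mask,            # [batch, seq_len]
        'class_index': class_index,          # [batch]
        'paragraph_mask': paragraph_mask,    # [batch, seq_len]
        'start_positions': start_positions,  # [batch]; training only
    })
    ```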
  """

  def __init__(
      self,
      network: Union[tf_keras.layers.Layer, tf_keras.Model],
      start_n_top: int = 5,
      end_n_top: int = 5,
      dropout_rate: float = 0.1,
      span_labeling_activation: str = 'tanh',
      initializer: Union[
          str, tf_keras.initializers.Initializer] = 'glorot_uniform',
      **kwargs):
    super().__init__(**kwargs)
    self._config = {
        'network': network,
        'start_n_top': start_n_top,
        'end_n_top': end_n_top,
        'dropout_rate': dropout_rate,
        'span_labeling_activation': span_labeling_activation,
        'initializer': initializer,
    }
    network_config = network.get_config()
    try:
      input_width = network_config['inner_size']
      self._xlnet_base = True
    except KeyError:
      # BertEncoder uses 'intermediate_size' due to legacy naming.
      input_width = network_config['intermediate_size']
      self._xlnet_base = False

    self._network = network
    self._initializer = initializer
    self._start_n_top = start_n_top
    self._end_n_top = end_n_top
    self._dropout_rate = dropout_rate
    self._activation = span_labeling_activation
    self.span_labeling = networks.XLNetSpanLabeling(
        input_width=input_width,
        start_n_top=self._start_n_top,
        end_n_top=self._end_n_top,
        activation=self._activation,
        dropout_rate=self._dropout_rate,
        initializer=self._initializer)

  def call(self, inputs: Mapping[str, Any]):  # pytype: disable=signature-mismatch  # overriding-parameter-count-checks
    input_word_ids = inputs['input_word_ids']
    input_type_ids = inputs['input_type_ids']
    input_mask = inputs['input_mask']
    class_index = inputs['class_index']
    paragraph_mask = inputs['paragraph_mask']
    start_positions = inputs.get('start_positions', None)

    if self._xlnet_base:
      attention_output, _ = self._network(
          input_ids=input_word_ids,
          segment_ids=input_type_ids,
          input_mask=input_mask)
    else:
      network_output_dict = self._network(dict(
          input_word_ids=input_word_ids,
          input_type_ids=input_type_ids,
          input_mask=input_mask))
      attention_output = network_output_dict['sequence_output']

    outputs = self.span_labeling(
        sequence_data=attention_output,
        class_index=class_index,
        paragraph_mask=paragraph_mask,
        start_positions=start_positions)
    return outputs

  @property
  def checkpoint_items(self):
    return dict(encoder=self._network)

  def get_config(self):
    return self._config

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)