# Copyright 2023 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Keras layers of XLNet model in TF 2.0.""" import copy import warnings import tensorflow as tf, tf_keras from official.legacy.xlnet import data_utils from official.nlp.modeling import networks def gelu(x): return tf_keras.activations.gelu(x, approximate=True) def _get_initializer(flags): """Get variable initializer.""" if flags.init_method == "uniform": initializer = tf_keras.initializers.RandomUniform( minval=-flags.init_range, maxval=flags.init_range) elif flags.init_method == "normal": initializer = tf_keras.initializers.RandomNormal(stddev=flags.init_std) else: raise ValueError("Initializer {} not supported".format(flags.init_method)) return initializer def rel_shift(x, klen=-1): """Performs relative shift to form the relative attention score.""" x_size = tf.shape(x) x = tf.reshape(x, [x_size[1], x_size[0], x_size[2], x_size[3]]) x = tf.slice(x, [1, 0, 0, 0], [-1, -1, -1, -1]) x = tf.reshape(x, [x_size[0], x_size[1] - 1, x_size[2], x_size[3]]) x = tf.slice(x, [0, 0, 0, 0], [-1, klen, -1, -1]) return x def _create_mask(qlen, mlen, dtype=tf.float32, same_length=False): """Creates attention mask when single-side context allowed only.""" attn_mask = tf.ones([qlen, qlen], dtype=dtype) mask_u = tf.linalg.band_part(attn_mask, 0, -1) mask_dia = tf.linalg.band_part(attn_mask, 0, 0) attn_mask_pad = tf.zeros([qlen, mlen], dtype=dtype) ret = tf.concat([attn_mask_pad, mask_u - mask_dia], 1) if same_length: mask_l = tf.linalg.band_part(attn_mask, -1, 0) ret = tf.concat([ret[:, :qlen] + mask_l - mask_dia, ret[:, qlen:]], 1) return ret def _cache_mem(curr_out, prev_mem, mem_len, reuse_len=None): """cache hidden states into memory.""" if mem_len is None or mem_len == 0: return None else: if reuse_len is not None and reuse_len > 0: curr_out = curr_out[:reuse_len] if prev_mem is None: new_mem = curr_out[-mem_len:] else: new_mem = tf.concat([prev_mem, curr_out], 0)[-mem_len:] return tf_keras.backend.stop_gradient(new_mem) def is_special_none_tensor(tensor): """Checks if a tensor is a special None Tensor.""" return tensor.shape.ndims == 0 and tensor.dtype == tf.int32 @tf_keras.utils.register_keras_serializable(package="Text") class RelativePositionEncoding(tf_keras.layers.Layer): """Creates a relative positional encoding. This layer creates a relative positional encoding as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" (https://arxiv.org/abs/1901.02860). Rather than an absolute position embedding as in Transformer, this formulation represents position as the relative distance between tokens using sinusoidal positional embeddings. Note: This layer is currently experimental. Attributes: hidden_size: The dimensionality of the input embeddings. 
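  Example (an illustrative sketch; shapes follow the `call()` docstring below):

    pos_encoding = RelativePositionEncoding(hidden_size=16)
    pos_seq = tf.range(5.0, -1.0, -1.0)  # relative distances 5, 4, ..., 0
    pos_emb = pos_encoding(pos_seq, batch_size=2)  # shape [6, 2, 16]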
""" def __init__(self, hidden_size, **kwargs): super(RelativePositionEncoding, self).__init__(**kwargs) self._hidden_size = hidden_size self._inv_freq = 1.0 / (10000.0**( tf.range(0, self._hidden_size, 2.0) / self._hidden_size)) def call(self, pos_seq, batch_size=None): """Implements call() for the layer. Args: pos_seq: A 1-D `Tensor` batch_size: The optionally provided batch size that tiles the relative positional encoding. Returns: The relative positional encoding of shape: [len(pos_seq), batch_size, hidden_size] if batch_size is provided, else [len(pos_seq), 1, hidden_size]. """ sinusoid_input = tf.einsum("i,d->id", pos_seq, self._inv_freq) pos_emb = tf.concat([tf.sin(sinusoid_input), tf.cos(sinusoid_input)], -1) pos_emb = pos_emb[:, None, :] if batch_size is not None: pos_emb = tf.tile(pos_emb, [1, batch_size, 1]) return pos_emb class RelativeAttention(tf_keras.layers.Layer): """Core calculations for relative attention.""" def __init__(self, dropout_att, scale): super(RelativeAttention, self).__init__() self.scale = scale self.dropout_att = dropout_att def build(self, unused_input_shapes): """Implements build() for the layer.""" self.attention_probs_dropout = tf_keras.layers.Dropout( rate=self.dropout_att) super(RelativeAttention, self).build(unused_input_shapes) def call(self, q_head, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat, r_w_bias, r_r_bias, r_s_bias, attn_mask): """Implements call() for the layer.""" # content based attention score ac = tf.einsum("ibnd,jbnd->ijbn", q_head + r_w_bias, k_head_h) # position based attention score bd = tf.einsum("ibnd,jbnd->ijbn", q_head + r_r_bias, k_head_r) bd = rel_shift(bd, klen=tf.shape(ac)[1]) # segment-based attention score if seg_mat is None: ef = 0 else: ef = tf.einsum("ibnd,snd->isbn", q_head + r_s_bias, seg_embed) tgt_shape = tf.shape(bd) ef = tf.where( tf.broadcast_to(tf.expand_dims(seg_mat, 3), tgt_shape), tf.broadcast_to(ef[:, 1:, :, :], tgt_shape), tf.broadcast_to(ef[:, :1, :, :], tgt_shape)) # merges attention scores and performs masking attn_score = (ac + bd + ef) * self.scale if attn_mask is not None: attn_score = attn_score - 1e30 * attn_mask # attention probability attn_prob = tf.nn.softmax(attn_score, 1) attn_prob = self.attention_probs_dropout(attn_prob) # attention output attn_vec = tf.einsum("ijbn,jbnd->ibnd", attn_prob, v_head_h) return attn_vec class PositionwiseFF(tf_keras.layers.Layer): """Positionwise feed-forward layer.""" def __init__(self, d_model, d_inner, dropout, kernel_initializer, activation_type, **kwargs): super(PositionwiseFF, self).__init__(**kwargs) self.d_model = d_model self.d_inner = d_inner self.dropout = dropout self.activation_type = activation_type self.kernel_initializer = kernel_initializer def build(self, unused_input_shapes): """Implements build() for the layer.""" if self.activation_type == "relu": activation = tf.nn.relu elif self.activation_type == "gelu": activation = gelu else: raise (ValueError("Unsupported activation type {}".format( self.activation_type))) self.inner_projection_layer = ( tf_keras.layers.Dense( units=self.d_inner, activation=activation, kernel_initializer=self.kernel_initializer, name="layer_1")) self.output_projection_layer = ( tf_keras.layers.Dense( units=self.d_model, kernel_initializer=self.kernel_initializer, name="layer_2")) self.output_dropout = tf_keras.layers.Dropout( rate=self.dropout, name="drop_2") self.output_layer_norm = ( tf_keras.layers.LayerNormalization( name="LayerNorm", axis=-1, epsilon=1e-12)) super(PositionwiseFF, self).build(unused_input_shapes) 
def call(self, inp): """Implements call() for the layer.""" output = self.inner_projection_layer(inp) output = self.output_projection_layer(output) output = self.output_dropout(output) output = self.output_layer_norm(output + inp) return output class EmbeddingLookup(tf_keras.layers.Layer): """Looks up words embeddings for id tensor.""" def __init__(self, n_token, d_embed, initializer, **kwargs): super(EmbeddingLookup, self).__init__(**kwargs) self.n_token = n_token self.d_embed = d_embed self.initializer = initializer def build(self, unused_input_shapes): """Implements build() for the layer.""" self.lookup_table = self.add_weight( "lookup_table", shape=[self.n_token, self.d_embed], initializer=self.initializer, dtype=self.dtype) super(EmbeddingLookup, self).build(unused_input_shapes) def call(self, inputs): return tf.nn.embedding_lookup(self.lookup_table, inputs) class RelativeMultiheadAttention(tf_keras.layers.Layer): """Multi-head attention with relative embedding.""" def __init__(self, d_model, n_head, d_head, dropout, dropout_att, kernel_initializer, **kwargs): super(RelativeMultiheadAttention, self).__init__(**kwargs) self.d_model = d_model self.n_head = n_head self.d_head = d_head self.dropout = dropout self.dropout_att = dropout_att self.initializer = kernel_initializer def build(self, unused_input_shapes): """Implements build() for the layer.""" self.scale = 1.0 / (self.d_head**0.5) self.output_layer_norm = tf_keras.layers.LayerNormalization( name="LayerNorm", axis=-1, epsilon=1e-12) self.kh_projection_layer = self.add_weight( "k/kernel", shape=[self.d_model, self.n_head, self.d_head], initializer=self.initializer) self.vh_projection_layer = self.add_weight( "v/kernel", shape=[self.d_model, self.n_head, self.d_head], initializer=self.initializer) self.kr_projection_layer = self.add_weight( "r/kernel", shape=[self.d_model, self.n_head, self.d_head], initializer=self.initializer) self.qh_projection_layer = self.add_weight( "q/kernel", shape=[self.d_model, self.n_head, self.d_head], initializer=self.initializer) self.relative_attention_layer = RelativeAttention( dropout_att=self.dropout_att, scale=self.scale) self.proj_o = self.add_weight( "o/kernel", shape=[self.d_model, self.n_head, self.d_head], initializer=self.initializer) self.attention_dropout = tf_keras.layers.Dropout(rate=self.dropout) super(RelativeMultiheadAttention, self).build(unused_input_shapes) def call(self, h, g, r, r_w_bias, r_r_bias, seg_mat, r_s_bias, seg_embed, attn_mask_h, attn_mask_g, mems, target_mapping): """Implements call() for the layer.""" if mems is not None and mems.shape.ndims > 1: cat = tf.concat([mems, h], 0) else: cat = h # content heads q_head_h = tf.einsum("ibh,hnd->ibnd", h, self.qh_projection_layer) k_head_h = tf.einsum("ibh,hnd->ibnd", cat, self.kh_projection_layer) v_head_h = tf.einsum("ibh,hnd->ibnd", cat, self.vh_projection_layer) # positional heads k_head_r = tf.einsum("ibh,hnd->ibnd", r, self.kr_projection_layer) # core attention ops attn_vec_h = self.relative_attention_layer(q_head_h, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat, r_w_bias, r_r_bias, r_s_bias, attn_mask_h) # post processing output_h = tf.einsum("ibnd,hnd->ibh", attn_vec_h, self.proj_o) output_h = self.attention_dropout(output_h) output_h = self.output_layer_norm(output_h + h) output_g = None if g is not None: # enable two-stream attention # g-stream q_head_g = tf.einsum("ibh,hnd->ibnd", g, self.qh_projection_layer) if target_mapping is not None: q_head_g = tf.einsum("mbnd,mlb->lbnd", q_head_g, target_mapping) 
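        # `target_mapping` has shape [num_predict, qlen, bsz]: the einsum above
        # scatters the num_predict query-stream head vectors onto their
        # original sequence positions so that relative attention (and
        # rel_shift) operate over the full query length; the attention output
        # is gathered back to the num_predict positions below.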
        attn_vec_g = self.relative_attention_layer(q_head_g, k_head_h,
                                                   v_head_h, k_head_r,
                                                   seg_embed, seg_mat,
                                                   r_w_bias, r_r_bias,
                                                   r_s_bias, attn_mask_g)
        attn_vec_g = tf.einsum("lbnd,mlb->mbnd", attn_vec_g, target_mapping)
      else:
        attn_vec_g = self.relative_attention_layer(q_head_g, k_head_h,
                                                   v_head_h, k_head_r,
                                                   seg_embed, seg_mat,
                                                   r_w_bias, r_r_bias,
                                                   r_s_bias, attn_mask_g)

      # post processing
      output_g = tf.einsum("ibnd,hnd->ibh", attn_vec_g, self.proj_o)
      output_g = self.attention_dropout(output_g)
      output_g = self.output_layer_norm(output_g + g)

    return (output_h, output_g)


class TransformerXLModel(tf_keras.layers.Layer):
  """Defines a Transformer-XL computation graph with additional support for XLNet."""

  def __init__(self,
               n_token,
               n_layer,
               d_model,
               n_head,
               d_head,
               d_inner,
               dropout,
               dropout_att,
               attn_type,
               bi_data,
               is_training,
               initializer,
               mem_len=None,
               same_length=False,
               clamp_len=-1,
               untie_r=False,
               use_tpu=True,
               reuse_len=None,
               ff_activation="relu",
               use_cls_mask=False,
               **kwargs):
    """Initializes TransformerXLModel.

    Args:
      n_token: int, the number of tokens in vocabulary.
      n_layer: int, the number of layers.
      d_model: int, the hidden size.
      n_head: int, the number of attention heads.
      d_head: int, the dimension size of each attention head.
      d_inner: int, the hidden size in feed-forward layers.
      dropout: float, dropout rate.
      dropout_att: float, dropout rate on attention probabilities.
      attn_type: str, "uni" or "bi".
      bi_data: bool, whether to use bidirectional input pipeline. Usually set to
        True during pretraining and False during finetuning.
      is_training: bool, whether in training mode.
      initializer: A tf initializer.
      mem_len: int, the number of tokens to cache.
      same_length: bool, whether to use the same attention length for each
        token.
      clamp_len: int, clamp all relative distances larger than clamp_len. -1
        means no clamping.
      untie_r: bool, whether to untie the biases in attention.
      use_tpu: bool, whether TPUs are used.
      reuse_len: int, the number of tokens in the current batch to be cached and
        reused in the future.
      ff_activation: str, "relu" or "gelu".
      use_cls_mask: bool, whether to introduce cls mask.
      **kwargs: Other parameters.
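
    Example (an illustrative sketch only; note that this layer is deprecated in
    favor of `networks.XLNetBase`):

      transformer_xl = TransformerXLModel(
          n_token=32000, n_layer=6, d_model=256, n_head=4, d_head=64,
          d_inner=1024, dropout=0.1, dropout_att=0.1, attn_type="bi",
          bi_data=False, is_training=True,
          initializer=tf_keras.initializers.RandomNormal(stddev=0.02),
          mem_len=128, ff_activation="gelu")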
""" super(TransformerXLModel, self).__init__(**kwargs) warnings.warn( "`TransformerXLModel` is deprecated, please use `XLNetBase` instead", DeprecationWarning, stacklevel=2) self.n_token = n_token self.initializer = initializer self.attn_type = attn_type self.n_layer = n_layer self.d_model = d_model self.n_head = n_head self.d_head = d_head self.d_inner = d_inner self.ff_activation = ff_activation self.untie_r = untie_r self.use_tpu = use_tpu self.dropout = dropout self.dropout_att = dropout_att self.mem_len = mem_len self.reuse_len = reuse_len self.bi_data = bi_data self.clamp_len = clamp_len self.same_length = same_length self.use_cls_mask = use_cls_mask def build(self, unused_input_shapes): """Implements build() for the layer.""" self.tf_float = tf.float32 self.embedding_lookup = EmbeddingLookup( n_token=self.n_token, d_embed=self.d_model, initializer=self.initializer, dtype=self.tf_float, name="word_embedding") self.h_dropout = tf_keras.layers.Dropout(rate=self.dropout) self.g_dropout = tf_keras.layers.Dropout(rate=self.dropout) if self.untie_r: self.r_w_bias = ( self.add_weight( "r_w_bias", shape=[self.n_layer, self.n_head, self.d_head], dtype=self.tf_float, initializer=self.initializer)) self.r_r_bias = ( self.add_weight( "r_r_bias", shape=[self.n_layer, self.n_head, self.d_head], dtype=self.tf_float, initializer=self.initializer)) self.r_s_bias = ( self.add_weight( "r_s_bias", shape=[self.n_layer, self.n_head, self.d_head], dtype=self.tf_float, initializer=self.initializer)) else: self.r_w_bias = ( self.add_weight( "r_w_bias", shape=[self.n_head, self.d_head], dtype=self.tf_float, initializer=self.initializer)) self.r_r_bias = ( self.add_weight( "r_r_bias", shape=[self.n_head, self.d_head], dtype=self.tf_float, initializer=self.initializer)) self.r_s_bias = ( self.add_weight( "r_s_bias", [self.n_head, self.d_head], dtype=self.tf_float, initializer=self.initializer)) self.seg_embed = self.add_weight( "seg_embed", [self.n_layer, 2, self.n_head, self.d_head], dtype=self.tf_float, initializer=self.initializer) self.mask_emb = self.add_weight( "mask_emb/mask_emb", shape=[1, 1, self.d_model], dtype=self.tf_float) self.emb_dropout = tf_keras.layers.Dropout(rate=self.dropout) self.fwd_position_embedding = RelativePositionEncoding(self.d_model) self.bwd_position_embedding = RelativePositionEncoding(self.d_model) self.rel_multihead_layers = [] self.h_positionwise_ffn_layers = [] for i in range(self.n_layer): self.rel_multihead_layers.append( RelativeMultiheadAttention( d_model=self.d_model, dropout=self.dropout, n_head=self.n_head, d_head=self.d_head, dropout_att=self.dropout_att, kernel_initializer=self.initializer, name="layer_%d/rel_attn" % (i))) self.h_positionwise_ffn_layers.append( PositionwiseFF( d_model=self.d_model, d_inner=self.d_inner, dropout=self.dropout, kernel_initializer=self.initializer, activation_type=self.ff_activation, name="layer_%d/ff" % (i))) self.output_dropout = tf_keras.layers.Dropout(rate=self.dropout) super(TransformerXLModel, self).build(unused_input_shapes) def __call__(self, inp_k, seg_id=None, input_mask=None, mems=None, perm_mask=None, target_mapping=None, inp_q=None, **kwargs): # Uses dict to feed inputs into call() in order to keep mems as a python # list. 
inputs = { "inp_k": inp_k, "seg_id": seg_id, "input_mask": input_mask, "mems": mems, "perm_mask": perm_mask, "target_mapping": target_mapping, "inp_q": inp_q } return super(TransformerXLModel, self).__call__(inputs, **kwargs) def call(self, inputs): """Implements call() for the layer.""" inp_k = inputs["inp_k"] seg_id = inputs["seg_id"] input_mask = inputs["input_mask"] mems = inputs["mems"] perm_mask = inputs["perm_mask"] target_mapping = inputs["target_mapping"] inp_q = inputs["inp_q"] new_mems = [] bsz = tf.shape(inp_k)[1] qlen = inp_k.shape.as_list()[0] mlen = mems[0].shape.as_list()[0] if mems is not None else 0 klen = mlen + qlen ##### Attention mask # causal attention mask if self.attn_type == "uni": attn_mask = _create_mask(qlen, mlen, self.tf_float, self.same_length) # pylint: enable=protected-access attn_mask = attn_mask[:, :, None, None] elif self.attn_type == "bi": attn_mask = None else: raise ValueError("Unsupported attention type: {}".format(self.attn_type)) # data mask: input mask & perm mask if input_mask is not None and perm_mask is not None: data_mask = input_mask[None] + perm_mask elif input_mask is not None and perm_mask is None: data_mask = input_mask[None] elif input_mask is None and perm_mask is not None: data_mask = perm_mask else: data_mask = None if data_mask is not None: # all mems can be attended to mems_mask = tf.zeros([tf.shape(data_mask)[0], mlen, bsz], dtype=self.tf_float) data_mask = tf.concat([mems_mask, data_mask], 1) if attn_mask is None: attn_mask = data_mask[:, :, :, None] else: attn_mask += data_mask[:, :, :, None] if attn_mask is not None: attn_mask = tf.cast(attn_mask > 0, dtype=self.tf_float) if attn_mask is not None: non_tgt_mask = -tf.eye(qlen, dtype=self.tf_float) non_tgt_mask = tf.concat( [tf.zeros([qlen, mlen], dtype=self.tf_float), non_tgt_mask], axis=-1) non_tgt_mask = tf.cast( (attn_mask + non_tgt_mask[:, :, None, None]) > 0, dtype=self.tf_float) else: non_tgt_mask = None word_emb_k = self.embedding_lookup(inp_k) if inp_q is not None: if target_mapping is not None: word_emb_q = tf.tile(self.mask_emb, [tf.shape(target_mapping)[0], bsz, 1]) else: inp_q_ext = inp_q[:, :, None] word_emb_q = inp_q_ext * self.mask_emb + (1 - inp_q_ext) * word_emb_k output_h = self.h_dropout(word_emb_k) output_g = None if inp_q is not None: output_g = self.g_dropout(word_emb_q) ##### Segment embedding if seg_id is not None: # Convert `seg_id` to one-hot `seg_mat` mem_pad = tf.zeros([mlen, bsz], dtype=tf.int32) cat_id = tf.concat([mem_pad, seg_id], 0) if self.use_cls_mask: # `1` indicates not in the same segment [qlen x klen x bsz] # seg_id: [qlen x bsz] & cat_id: [klen x bsz] cls_mat = tf.logical_or( tf.equal(seg_id, tf.constant([data_utils.SEG_ID_CLS]))[:, None], tf.equal(cat_id, tf.constant([data_utils.SEG_ID_CLS]))[None, :]) seg_mat = tf.equal(seg_id[:, None], cat_id[None, :]) seg_mat = tf.logical_or(cls_mat, seg_mat) else: seg_mat = tf.logical_not(tf.equal(seg_id[:, None], cat_id[None, :])) else: seg_mat = None dtype = self.tf_float freq_seq = tf.range(0, self.d_model, 2.0) if dtype is not None and dtype != tf.float32: freq_seq = tf.cast(freq_seq, dtype=self.dtype) if self.attn_type == "bi": beg, end = klen, -qlen elif self.attn_type == "uni": beg, end = klen, -1 else: raise ValueError("Unknown `attn_type` {}.".format(self.attn_type)) if self.bi_data: fwd_pos_seq = tf.range(beg, end, -1.0) bwd_pos_seq = tf.range(-beg, -end, 1.0) if dtype is not None and dtype != tf.float32: fwd_pos_seq = tf.cast(fwd_pos_seq, dtype=dtype) bwd_pos_seq = tf.cast(bwd_pos_seq, 
                                dtype=dtype)

      if self.clamp_len > 0:
        fwd_pos_seq = tf.clip_by_value(fwd_pos_seq, -self.clamp_len,
                                       self.clamp_len)
        bwd_pos_seq = tf.clip_by_value(bwd_pos_seq, -self.clamp_len,
                                       self.clamp_len)

      if bsz is not None:
        fwd_pos_emb = self.fwd_position_embedding(fwd_pos_seq, bsz // 2)
        bwd_pos_emb = self.bwd_position_embedding(bwd_pos_seq, bsz // 2)
      else:
        fwd_pos_emb = self.fwd_position_embedding(fwd_pos_seq, None)
        bwd_pos_emb = self.bwd_position_embedding(bwd_pos_seq, None)

      pos_emb = tf.concat([fwd_pos_emb, bwd_pos_emb], axis=1)
    else:
      fwd_pos_seq = tf.range(beg, end, -1.0)
      if dtype is not None and dtype != tf.float32:
        fwd_pos_seq = tf.cast(fwd_pos_seq, dtype=dtype)
      if self.clamp_len > 0:
        fwd_pos_seq = tf.clip_by_value(fwd_pos_seq, -self.clamp_len,
                                       self.clamp_len)
      pos_emb = self.fwd_position_embedding(fwd_pos_seq, bsz)

    pos_emb = self.emb_dropout(pos_emb)

    if mems is None:
      mems = [None] * self.n_layer

    for i in range(self.n_layer):
      # cache new mems
      new_mems.append(
          _cache_mem(output_h, mems[i], self.mem_len, self.reuse_len))
      # pylint: enable=protected-access

      # segment bias
      if seg_id is None:
        r_s_bias_i = None
        seg_embed_i = None
      else:
        r_s_bias_i = self.r_s_bias if not self.untie_r else self.r_s_bias[i]
        seg_embed_i = self.seg_embed[i]

      ffn_layer = self.h_positionwise_ffn_layers[i]
      attention_layer = self.rel_multihead_layers[i]
      output_h, output_g = attention_layer(
          h=output_h,
          g=output_g,
          r=pos_emb,
          r_w_bias=self.r_w_bias if not self.untie_r else self.r_w_bias[i],
          r_r_bias=self.r_r_bias if not self.untie_r else self.r_r_bias[i],
          seg_mat=seg_mat,
          r_s_bias=r_s_bias_i,
          seg_embed=seg_embed_i,
          attn_mask_h=non_tgt_mask,
          attn_mask_g=attn_mask,
          mems=mems[i],
          target_mapping=target_mapping)
      output_h = ffn_layer(output_h)
      if output_g is not None:
        output_g = ffn_layer(output_g)

    if inp_q is not None:
      output = output_g
    else:
      output = output_h

    return output, new_mems, None


class PretrainingXLNetModel(tf_keras.Model):
  """XLNet keras model combined with pretraining LM loss layer.
See the original paper: https://arxiv.org/pdf/1906.08237.pdf """ def __init__(self, use_proj, xlnet_config, run_config, use_legacy_mask=True, **kwargs): super(PretrainingXLNetModel, self).__init__(**kwargs) self.run_config = run_config self.initializer = _get_initializer(run_config) self.xlnet_config = copy.deepcopy(xlnet_config) self._use_legacy_mask = use_legacy_mask self.xlnet_model = networks.XLNetBase( vocab_size=self.xlnet_config.n_token, initializer=self.initializer, attention_type="bi", num_layers=self.xlnet_config.n_layer, hidden_size=self.xlnet_config.d_model, num_attention_heads=self.xlnet_config.n_head, head_size=self.xlnet_config.d_head, inner_size=self.xlnet_config.d_inner, two_stream=True, tie_attention_biases=not self.xlnet_config.untie_r, inner_activation=self.xlnet_config.ff_activation, dropout_rate=self.run_config.dropout, attention_dropout_rate=self.run_config.dropout_att, memory_length=self.run_config.mem_len, reuse_length=self.run_config.reuse_len, bi_data=self.run_config.bi_data, clamp_length=self.run_config.clamp_len, use_cls_mask=self.run_config.use_cls_mask, name="xlnet_model") self.lmloss_layer = LMLossLayer( vocab_size=self.xlnet_config.n_token, hidden_size=self.xlnet_config.d_model, initializer=self.initializer, tie_weight=True, bi_data=self.run_config.bi_data, use_one_hot=self.run_config.use_tpu, use_proj=use_proj, name="lm_loss") def call(self, features): """Implements call() for the layer.""" input_ids = features["input_ids"] masked_tokens = features["input_q"] seg_ids = features["seg_id"] if self._use_legacy_mask: # Legacy input mask assumes `real` values are 0 and `padding` # values are 1. perm_mask = 1 - features["perm_mask"] else: perm_mask = features["perm_mask"] target_mapping = features["target_mapping"] # target for LM loss target = features["target"] # target mask for LM loss tgt_mask = features["target_mask"] mems = features.get("mems", None) model_output, self.new_mems = self.xlnet_model( input_ids=input_ids, segment_ids=seg_ids, input_mask=None, state=mems, permutation_mask=perm_mask, target_mapping=target_mapping, masked_tokens=masked_tokens) lm_loss, _ = self.lmloss_layer( hidden=model_output, target=target, lookup_table=self.xlnet_model.get_embedding_lookup_table(), target_mask=tgt_mask) self.add_loss(lm_loss) return self.new_mems, model_output class ClassificationXLNetModel(tf_keras.Model): """XLNet keras model combined with classification loss layer. 
See the original paper: https://arxiv.org/pdf/1906.08237.pdf """ def __init__(self, xlnet_config, run_config, n_class, summary_type, use_legacy_mask=True, **kwargs): super(ClassificationXLNetModel, self).__init__(**kwargs) warnings.warn( "`ClassificationXLNetModel` is deprecated, please use `XLNetClassifier`" "instead.", DeprecationWarning, stacklevel=2) self.run_config = run_config self.initializer = _get_initializer(run_config) self.xlnet_config = copy.deepcopy(xlnet_config) self._use_legacy_mask = use_legacy_mask self.xlnet_model = networks.XLNetBase( vocab_size=self.xlnet_config.n_token, initializer=self.initializer, attention_type="bi", num_layers=self.xlnet_config.n_layer, hidden_size=self.xlnet_config.d_model, num_attention_heads=self.xlnet_config.n_head, head_size=self.xlnet_config.d_head, inner_size=self.xlnet_config.d_inner, two_stream=False, tie_attention_biases=not self.xlnet_config.untie_r, inner_activation=self.xlnet_config.ff_activation, dropout_rate=self.run_config.dropout, attention_dropout_rate=self.run_config.dropout_att, memory_length=self.run_config.mem_len, reuse_length=self.run_config.reuse_len, bi_data=self.run_config.bi_data, clamp_length=self.run_config.clamp_len, use_cls_mask=False, name="xlnet_model") self.summarization_layer = Summarization( hidden_size=self.xlnet_config.d_model, num_attention_heads=self.xlnet_config.n_head, head_size=self.xlnet_config.d_head, dropout_rate=self.run_config.dropout, attention_dropout_rate=self.run_config.dropout_att, initializer=self.initializer, use_proj=True, summary_type=summary_type, name="sequence_summary") self.cl_loss_layer = ClassificationLossLayer( n_class=n_class, initializer=self.initializer, name="classification") def call(self, features): """Implements call() for the layer.""" batch_size_per_core = tf.shape(features["input_ids"])[0] input_ids = features["input_ids"] segment_ids = features["segment_ids"] if self._use_legacy_mask: # Legacy input mask assumes `real` values are 0 and `padding` # values are 1. input_mask = 1 - features["input_mask"] else: input_mask = features["input_mask"] label = tf.reshape(features["label_ids"], [batch_size_per_core]) mems = features.get("mems", None) attention_output, new_mems = ( self.xlnet_model(input_ids, segment_ids, input_mask, mems)) summary = self.summarization_layer(attention_output) per_example_loss, logits = self.cl_loss_layer(hidden=summary, labels=label) self.add_loss(tf_keras.backend.mean(per_example_loss)) return new_mems, logits class LMLossLayer(tf_keras.layers.Layer): """Layer computing cross entropy loss for language modeling.""" def __init__(self, vocab_size, hidden_size, initializer, tie_weight=False, bi_data=True, use_one_hot=False, use_proj=False, **kwargs): """Constructs LMLoss layer. Args: vocab_size: Number of tokens in vocabulary. hidden_size: The dimension of model hidden state. initializer: Initializer used for parameters. tie_weight: Whether to share weights between embedding lookup layer and next-token prediction layer. bi_data: Whether to use bidirectional input pipeline. Usually set to True during pretraining and False during finetuning. use_one_hot: bool, whether to use one hot encodings. This should be used when TPUs are used. use_proj: bool, whether to add a projection layer before LM prediction. **kwargs: Other parameters. 
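
    Example (an illustrative sketch; the tensors below are placeholders, and
    the two leading dimensions of `hidden` must match those of `target` and
    `target_mask`):

      lm_loss_layer = LMLossLayer(
          vocab_size=32000, hidden_size=256,
          initializer=tf_keras.initializers.RandomNormal(stddev=0.02),
          tie_weight=True, use_proj=True)
      total_loss, logits = lm_loss_layer(
          hidden=model_output,           # [..., ..., hidden_size]
          target=target_ids,             # integer token ids
          lookup_table=embedding_table,  # [vocab_size, hidden_size]
          target_mask=target_mask)       # 1.0 at positions included in the loss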
""" super(LMLossLayer, self).__init__(**kwargs) self.vocab_size = vocab_size self.hidden_size = hidden_size self.initializer = initializer self.tie_weight = tie_weight self.bi_data = bi_data self.use_one_hot = use_one_hot self.use_proj = use_proj def build(self, unused_input_shapes): """Implements build() for the layer.""" if self.use_proj: self.proj_layer = tf_keras.layers.Dense( units=self.hidden_size, kernel_initializer=self.initializer, activation=gelu, name="lm_projection/dense") self.proj_layer_norm = tf_keras.layers.LayerNormalization( axis=-1, epsilon=1e-12, name="lm_projection/LayerNorm") if not self.tie_weight: self.softmax_w = self.add_weight( "weight", shape=[self.vocab_size, self.hidden_size], initializer=self.initializer) self.softmax_b = self.add_weight( "bias", shape=[self.vocab_size], initializer=tf.zeros_initializer()) super(LMLossLayer, self).build(unused_input_shapes) def call(self, hidden, target, lookup_table, target_mask): """Implements call() for the layer.""" if self.use_proj: hidden = self.proj_layer_norm(self.proj_layer(hidden)) if self.tie_weight: logits = tf.einsum("ibd,nd->ibn", hidden, lookup_table) + self.softmax_b else: logits = tf.einsum("ibd,nd->ibn", hidden, self.softmax_w) + self.softmax_b if self.use_one_hot: one_hot_target = tf.one_hot(target, self.vocab_size, dtype=logits.dtype) loss = -tf.reduce_sum(tf.nn.log_softmax(logits) * one_hot_target, -1) else: loss = tf.nn.sparse_softmax_cross_entropy_with_logits( labels=target, logits=logits) total_loss = tf.reduce_sum(loss * target_mask) / tf.reduce_sum(target_mask) return total_loss, logits class Summarization(tf_keras.layers.Layer): """The layer to pool the output from XLNet model into a vector.""" def __init__(self, hidden_size, num_attention_heads, head_size, dropout_rate, attention_dropout_rate, initializer, use_proj=True, summary_type="last", **kwargs): """Constructs Summarization layer. Args: hidden_size: int, the dimension of model hidden state. num_attention_heads: int, the number of attention heads. head_size: int, the dimension size of each attention head. dropout_rate: float, dropout rate. attention_dropout_rate: float, dropout rate on attention probabilities. initializer: Initializer used for parameters. use_proj: bool, whether to use projection layer for summarization. summary_type: Method used to summarize a sequence into a compact vector. **kwargs: Other parameters. 
""" super(Summarization, self).__init__(**kwargs) self.hidden_size = hidden_size self.num_attention_heads = num_attention_heads self.head_size = head_size self.initializer = initializer self.dropout_rate = dropout_rate self.attention_dropout_rate = attention_dropout_rate self.use_proj = use_proj self.summary_type = summary_type def build(self, unused_input_shapes): """Implements build() for the layer.""" if self.use_proj: self.proj_layer = tf_keras.layers.Dense( units=self.hidden_size, kernel_initializer=self.initializer, activation=tf.nn.tanh, name="summary") self.dropout_layer = tf_keras.layers.Dropout(rate=self.dropout_rate) super(Summarization, self).build(unused_input_shapes) def call(self, inputs): """Implements call() for the layer.""" if self.summary_type == "last": summary = inputs[:, -1, :] elif self.summary_type == "first": summary = inputs[:, 0, :] else: raise ValueError("Invalid summary type provided: %s" % self.summary_type) if self.use_proj: summary = self.proj_layer(summary) summary = self.dropout_layer(summary) return summary class ClassificationLossLayer(tf_keras.layers.Layer): """Layer computing cross entropy loss for classification task.""" def __init__(self, n_class, initializer, **kwargs): """Constructs Summarization layer. Args: n_class: Number of tokens in vocabulary. initializer: Initializer used for parameters. **kwargs: Other parameters. """ super(ClassificationLossLayer, self).__init__(**kwargs) self.n_class = n_class self.initializer = initializer def build(self, unused_input_shapes): """Implements build() for the layer.""" self.proj_layer = tf_keras.layers.Dense( units=self.n_class, kernel_initializer=self.initializer, name="logit") super(ClassificationLossLayer, self).build(unused_input_shapes) def call(self, hidden, labels): """Implements call() for the layer.""" logits = self.proj_layer(hidden) one_hot_target = tf.one_hot(labels, self.n_class, dtype=hidden.dtype) # pytype: disable=attribute-error loss = -tf.reduce_sum(tf.nn.log_softmax(logits) * one_hot_target, -1) return loss, logits class QAXLNetModel(tf_keras.Model): """XLNet keras model combined with question answering loss layer. 
See the original paper: https://arxiv.org/pdf/1906.08237.pdf """ def __init__(self, xlnet_config, run_config, start_n_top, end_n_top, use_legacy_mask=True, **kwargs): super(QAXLNetModel, self).__init__(**kwargs) warnings.warn( "`QAXLNetModel` is deprecated, please use `XLNetSpanLabeler` instead.", DeprecationWarning, stacklevel=2) self.run_config = run_config self.initializer = _get_initializer(run_config) self.xlnet_config = copy.deepcopy(xlnet_config) self._use_legacy_mask = use_legacy_mask self.xlnet_model = networks.XLNetBase( vocab_size=self.xlnet_config.n_token, initializer=self.initializer, attention_type="bi", num_layers=self.xlnet_config.n_layer, hidden_size=self.xlnet_config.d_model, num_attention_heads=self.xlnet_config.n_head, head_size=self.xlnet_config.d_head, inner_size=self.xlnet_config.d_inner, tie_attention_biases=not self.xlnet_config.untie_r, inner_activation=self.xlnet_config.ff_activation, dropout_rate=self.run_config.dropout, attention_dropout_rate=self.run_config.dropout_att, two_stream=False, memory_length=self.run_config.mem_len, reuse_length=self.run_config.reuse_len, bi_data=self.run_config.bi_data, clamp_length=self.run_config.clamp_len, use_cls_mask=False, name="xlnet_model") self.qa_loss_layer = QALossLayer( hidden_size=self.xlnet_config.d_model, start_n_top=start_n_top, end_n_top=end_n_top, initializer=self.initializer, dropout_rate=self.run_config.dropout, name="qa_loss_layer") def call(self, features, training=False): """Implements call() for the layer.""" input_ids = features["input_ids"] segment_ids = features["segment_ids"] if self._use_legacy_mask: # Legacy input mask assumes `real` values are 0 and `padding` # values are 1. input_mask = 1 - features["input_mask"] else: input_mask = features["input_mask"] cls_index = tf.reshape(features["cls_index"], [-1]) p_mask = features["p_mask"] attention_output, new_mems = ( self.xlnet_model(input_ids, segment_ids, input_mask)) if training: loss, logits = self.qa_loss_layer( hidden=attention_output, p_mask=p_mask, cls_index=cls_index, start_positions=features["start_positions"], end_positions=features["end_positions"], is_impossible=features["is_impossible"]) self.add_loss(loss) return new_mems, logits else: results = self.qa_loss_layer( hidden=attention_output, p_mask=p_mask, cls_index=cls_index) return results class QALossLayer(tf_keras.layers.Layer): """Layer computing position and regression loss for question answering task.""" def __init__(self, hidden_size, start_n_top, end_n_top, initializer, dropout_rate, **kwargs): """Constructs Summarization layer. Args: hidden_size: Int, the hidden size. start_n_top: Beam size for span start. end_n_top: Beam size for span end. initializer: Initializer used for parameters. dropout_rate: float, dropout rate. **kwargs: Other parameters. 
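
    Example (an illustrative inference-time sketch; the tensors below are
    placeholders, and outside of training the layer returns a dict of
    beam-searched span predictions rather than a loss):

      qa_loss_layer = QALossLayer(
          hidden_size=256, start_n_top=5, end_n_top=5,
          initializer=tf_keras.initializers.RandomNormal(stddev=0.02),
          dropout_rate=0.1)
      results = qa_loss_layer(
          hidden=attention_output, p_mask=p_mask, cls_index=cls_index)
      # results contains "start_top_log_probs", "start_top_index",
      # "end_top_log_probs", "end_top_index" and "cls_logits".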
""" super(QALossLayer, self).__init__(**kwargs) self.hidden_size = hidden_size self.start_n_top = start_n_top self.end_n_top = end_n_top self.initializer = initializer self.dropout_rate = dropout_rate def build(self, unused_input_shapes): """Implements build() for the layer.""" self.start_logits_proj_layer = tf_keras.layers.Dense( units=1, kernel_initializer=self.initializer, name="start_logits/dense") self.end_logits_proj_layer0 = tf_keras.layers.Dense( units=self.hidden_size, kernel_initializer=self.initializer, activation=tf.nn.tanh, name="end_logits/dense_0") self.end_logits_proj_layer1 = tf_keras.layers.Dense( units=1, kernel_initializer=self.initializer, name="end_logits/dense_1") self.end_logits_layer_norm = tf_keras.layers.LayerNormalization( axis=-1, epsilon=1e-12, name="end_logits/LayerNorm") self.answer_class_proj_layer0 = tf_keras.layers.Dense( units=self.hidden_size, kernel_initializer=self.initializer, activation=tf.nn.tanh, name="answer_class/dense_0") self.answer_class_proj_layer1 = tf_keras.layers.Dense( units=1, kernel_initializer=self.initializer, use_bias=False, name="answer_class/dense_1") self.ans_feature_dropout = tf_keras.layers.Dropout(rate=self.dropout_rate) super(QALossLayer, self).build(unused_input_shapes) def __call__(self, hidden, p_mask, cls_index, **kwargs): return super(QALossLayer, self).__call__( (hidden, p_mask, cls_index, kwargs)) def call(self, inputs, training=False): """Implements call() for the layer.""" hidden, p_mask, cls_index, kwargs = inputs return_dict = {} seq_len = tf.shape(hidden)[1] hidden = tf.transpose(hidden, [1, 0, 2]) start_logits = self.start_logits_proj_layer(hidden) start_logits = tf.transpose(tf.squeeze(start_logits, -1), [1, 0]) start_logits_masked = start_logits * (1 - p_mask) - 1e30 * p_mask start_log_probs = tf.nn.log_softmax(start_logits_masked, -1) if training: start_positions = kwargs["start_positions"] end_positions = kwargs["end_positions"] is_impossible = kwargs["is_impossible"] start_positions = tf.reshape(start_positions, [-1]) start_index = tf.one_hot( start_positions, depth=seq_len, axis=-1, dtype=tf.float32) start_features = tf.einsum("lbh,bl->bh", hidden, start_index) start_features = tf.tile(start_features[None], [seq_len, 1, 1]) end_logits = self.end_logits_proj_layer0( tf.concat([hidden, start_features], axis=-1)) end_logits = self.end_logits_layer_norm(end_logits) end_logits = self.end_logits_proj_layer1(end_logits) end_logits = tf.transpose(tf.squeeze(end_logits, -1), [1, 0]) end_logits_masked = end_logits * (1 - p_mask) - 1e30 * p_mask end_log_probs = tf.nn.log_softmax(end_logits_masked, -1) else: # during inference, compute the end logits based on beam search start_top_log_probs, start_top_index = tf.nn.top_k( start_log_probs, k=self.start_n_top) start_index = tf.one_hot( start_top_index, depth=seq_len, axis=-1, dtype=tf.float32) start_features = tf.einsum("lbh,bkl->bkh", hidden, start_index) end_input = tf.tile(hidden[:, :, None], [1, 1, self.start_n_top, 1]) start_features = tf.tile(start_features[None], [seq_len, 1, 1, 1]) end_input = tf.concat([end_input, start_features], axis=-1) end_logits = self.end_logits_proj_layer0(end_input) end_logits = tf.reshape(end_logits, [seq_len, -1, self.hidden_size]) end_logits = self.end_logits_layer_norm(end_logits) end_logits = tf.reshape(end_logits, [seq_len, -1, self.start_n_top, self.hidden_size]) end_logits = self.end_logits_proj_layer1(end_logits) end_logits = tf.reshape(end_logits, [seq_len, -1, self.start_n_top]) end_logits = tf.transpose(end_logits, [1, 2, 0]) 
end_logits_masked = end_logits * ( 1 - p_mask[:, None]) - 1e30 * p_mask[:, None] end_log_probs = tf.nn.log_softmax(end_logits_masked, -1) end_top_log_probs, end_top_index = tf.nn.top_k( end_log_probs, k=self.end_n_top) end_top_log_probs = tf.reshape(end_top_log_probs, [-1, self.start_n_top * self.end_n_top]) end_top_index = tf.reshape(end_top_index, [-1, self.start_n_top * self.end_n_top]) if training: return_dict["start_log_probs"] = start_log_probs return_dict["end_log_probs"] = end_log_probs else: return_dict["start_top_log_probs"] = start_top_log_probs return_dict["start_top_index"] = start_top_index return_dict["end_top_log_probs"] = end_top_log_probs return_dict["end_top_index"] = end_top_index # an additional layer to predict answerability # get the representation of CLS cls_index = tf.one_hot(cls_index, seq_len, axis=-1, dtype=tf.float32) cls_feature = tf.einsum("lbh,bl->bh", hidden, cls_index) # get the representation of START start_p = tf.nn.softmax(start_logits_masked, axis=-1, name="softmax_start") start_feature = tf.einsum("lbh,bl->bh", hidden, start_p) ans_feature = tf.concat([start_feature, cls_feature], -1) ans_feature = self.answer_class_proj_layer0(ans_feature) ans_feature = self.ans_feature_dropout(ans_feature) cls_logits = self.answer_class_proj_layer1(ans_feature) cls_logits = tf.squeeze(cls_logits, -1) return_dict["cls_logits"] = cls_logits if not training: return return_dict def compute_loss(log_probs, positions): one_hot_positions = tf.one_hot(positions, depth=seq_len, dtype=tf.float32) loss = -tf.reduce_sum(one_hot_positions * log_probs, axis=-1) loss = tf.reduce_mean(loss) return loss start_loss = compute_loss(start_log_probs, start_positions) end_loss = compute_loss(end_log_probs, end_positions) total_loss = (start_loss + end_loss) * 0.5 is_impossible = tf.reshape(is_impossible, [-1]) regression_loss = tf.nn.sigmoid_cross_entropy_with_logits( labels=is_impossible, logits=cls_logits) regression_loss = tf.reduce_mean(regression_loss) total_loss += regression_loss * 0.5 return total_loss, cls_logits