# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformer Encoders. | |
Includes configurations and factory methods. | |
""" | |
import dataclasses
from typing import Optional, Sequence, Union

import gin
import tensorflow as tf, tf_keras

from official.modeling import hyperparams
from official.modeling import tf_utils
from official.nlp.modeling import layers
from official.nlp.modeling import networks
from official.projects.bigbird import encoder as bigbird_encoder


class BertEncoderConfig(hyperparams.Config):
  """BERT encoder configuration."""
  vocab_size: int = 30522
  hidden_size: int = 768
  num_layers: int = 12
  num_attention_heads: int = 12
  hidden_activation: str = "gelu"
  intermediate_size: int = 3072
  dropout_rate: float = 0.1
  attention_dropout_rate: float = 0.1
  max_position_embeddings: int = 512
  type_vocab_size: int = 2
  initializer_range: float = 0.02
  embedding_size: Optional[int] = None
  output_range: Optional[int] = None
  return_all_encoder_outputs: bool = False
  return_attention_scores: bool = False
  # Pre/Post-LN Transformer
  norm_first: bool = False
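
# A minimal usage sketch: `hyperparams.Config` subclasses accept field
# overrides as keyword arguments, and `override()` (inherited from the
# params-dict base class) merges a dict of updates. The values below are
# illustrative assumptions, not recommended settings:
#
#   bert_cfg = BertEncoderConfig(num_layers=6, hidden_size=384)
#   bert_cfg.override({"num_attention_heads": 6})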


class FunnelEncoderConfig(hyperparams.Config):
  """Funnel encoder configuration."""
  vocab_size: int = 30522
  hidden_size: int = 768
  num_layers: int = 12
  num_attention_heads: int = 12
  max_position_embeddings: int = 512
  type_vocab_size: int = 16
  inner_dim: int = 3072
  hidden_activation: str = "gelu"
  approx_gelu: bool = True
  dropout_rate: float = 0.1
  attention_dropout_rate: float = 0.1
  pool_type: str = "max"
  pool_stride: Union[int, Sequence[Union[int, float]]] = 2
  unpool_length: int = 0
  initializer_range: float = 0.02
  output_range: Optional[int] = None
  embedding_width: Optional[int] = None
  embedding_layer: Optional[tf_keras.layers.Layer] = None
  norm_first: bool = False
  share_rezero: bool = False
  append_dense_inputs: bool = False
  transformer_cls: str = "TransformerEncoderBlock"
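
# Per the `pool_stride` annotation above, both a scalar and a per-layer
# sequence are accepted. A sketch of the two forms (the per-layer values are
# illustrative assumptions, not tuned settings):
#
#   FunnelEncoderConfig(pool_stride=2)          # same stride at every layer
#   FunnelEncoderConfig(pool_stride=(1, 2, 2))  # one stride per layer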


class MobileBertEncoderConfig(hyperparams.Config):
  """MobileBERT encoder configuration.

  Attributes:
    word_vocab_size: number of words in the vocabulary.
    word_embed_size: word embedding size.
    type_vocab_size: number of word types.
    max_sequence_length: maximum length of the input sequence.
    num_blocks: number of transformer blocks in the encoder model.
    hidden_size: the hidden size for the transformer block.
    num_attention_heads: number of attention heads in the transformer block.
    intermediate_size: the size of the "intermediate" (a.k.a., feed forward)
      layer.
    hidden_activation: the non-linear activation function to apply to the
      output of the intermediate/feed-forward layer.
    hidden_dropout_prob: dropout probability for the hidden layers.
    attention_probs_dropout_prob: dropout probability of the attention
      probabilities.
    intra_bottleneck_size: the size of the bottleneck.
    initializer_range: the stddev of the truncated_normal_initializer for
      initializing all weight matrices.
    use_bottleneck_attention: whether to use attention inputs from the
      bottleneck transformation. If true, `key_query_shared_bottleneck` below
      is ignored.
    key_query_shared_bottleneck: whether to share the linear transformation
      for keys and queries.
    num_feedforward_networks: number of stacked feed-forward networks.
    normalization_type: the type of normalization; only 'no_norm' and
      'layer_norm' are supported. 'no_norm' represents the element-wise linear
      transformation for the student model, as suggested by the original
      MobileBERT paper. 'layer_norm' is used for the teacher model.
    classifier_activation: whether to use the tanh activation for the final
      representation of the [CLS] token in fine-tuning.
    input_mask_dtype: dtype of the input mask.
  """
  word_vocab_size: int = 30522
  word_embed_size: int = 128
  type_vocab_size: int = 2
  max_sequence_length: int = 512
  num_blocks: int = 24
  hidden_size: int = 512
  num_attention_heads: int = 4
  intermediate_size: int = 4096
  hidden_activation: str = "gelu"
  hidden_dropout_prob: float = 0.1
  attention_probs_dropout_prob: float = 0.1
  intra_bottleneck_size: int = 1024
  initializer_range: float = 0.02
  use_bottleneck_attention: bool = False
  key_query_shared_bottleneck: bool = False
  num_feedforward_networks: int = 1
  normalization_type: str = "layer_norm"
  classifier_activation: bool = True
  input_mask_dtype: str = "int32"
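
# Sketch of the student/teacher split described in the docstring above
# (illustration only, not a tuned distillation recipe):
#
#   teacher_cfg = MobileBertEncoderConfig()  # defaults to 'layer_norm'
#   student_cfg = MobileBertEncoderConfig(normalization_type="no_norm")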


class AlbertEncoderConfig(hyperparams.Config):
  """ALBERT encoder configuration."""
  vocab_size: int = 30000
  embedding_width: int = 128
  hidden_size: int = 768
  num_layers: int = 12
  num_attention_heads: int = 12
  hidden_activation: str = "gelu"
  intermediate_size: int = 3072
  dropout_rate: float = 0.0
  attention_dropout_rate: float = 0.0
  max_position_embeddings: int = 512
  type_vocab_size: int = 2
  initializer_range: float = 0.02


class BigBirdEncoderConfig(hyperparams.Config):
  """BigBird encoder configuration."""
  vocab_size: int = 50358
  hidden_size: int = 768
  num_layers: int = 12
  num_attention_heads: int = 12
  hidden_activation: str = "gelu"
  intermediate_size: int = 3072
  dropout_rate: float = 0.1
  attention_dropout_rate: float = 0.1
  # Pre/Post-LN Transformer
  norm_first: bool = False
  max_position_embeddings: int = 4096
  num_rand_blocks: int = 3
  block_size: int = 64
  type_vocab_size: int = 16
  initializer_range: float = 0.02
  embedding_width: Optional[int] = None
  use_gradient_checkpointing: bool = False


class KernelEncoderConfig(hyperparams.Config):
  """Linear encoder configuration."""
  vocab_size: int = 30522
  hidden_size: int = 768
  num_layers: int = 12
  num_attention_heads: int = 12
  hidden_activation: str = "gelu"
  intermediate_size: int = 3072
  dropout_rate: float = 0.1
  attention_dropout_rate: float = 0.1
  # Pre/Post-LN Transformer
  norm_first: bool = False
  max_position_embeddings: int = 512
  type_vocab_size: int = 2
  initializer_range: float = 0.02
  embedding_size: Optional[int] = None
  feature_transform: str = "exp"
  num_random_features: int = 256
  redraw: bool = False
  is_short_seq: bool = False
  begin_kernel: int = 0
  scale: Optional[float] = None


class ReuseEncoderConfig(hyperparams.Config):
  """Reuse encoder configuration."""
  vocab_size: int = 30522
  hidden_size: int = 768
  num_layers: int = 12
  num_attention_heads: int = 12
  hidden_activation: str = "gelu"
  intermediate_size: int = 3072
  dropout_rate: float = 0.1
  attention_dropout_rate: float = 0.1
  max_position_embeddings: int = 512
  type_vocab_size: int = 2
  initializer_range: float = 0.02
  embedding_size: Optional[int] = None
  output_range: Optional[int] = None
  return_all_encoder_outputs: bool = False
  # Pre/Post-LN Transformer
  norm_first: bool = False
  # Reuse transformer
  reuse_attention: int = -1
  use_relative_pe: bool = False
  pe_max_seq_length: int = 512
  max_reuse_layer_idx: int = 6


class XLNetEncoderConfig(hyperparams.Config):
  """XLNet encoder configuration."""
  vocab_size: int = 32000
  num_layers: int = 24
  hidden_size: int = 1024
  num_attention_heads: int = 16
  head_size: int = 64
  inner_size: int = 4096
  inner_activation: str = "gelu"
  dropout_rate: float = 0.1
  attention_dropout_rate: float = 0.1
  attention_type: str = "bi"
  bi_data: bool = False
  tie_attention_biases: bool = False
  memory_length: int = 0
  same_length: bool = False
  clamp_length: int = -1
  reuse_length: int = 0
  use_cls_mask: bool = False
  embedding_width: int = 1024
  initializer_range: float = 0.02
  two_stream: bool = False


class QueryBertConfig(hyperparams.Config):
  """Query BERT encoder configuration."""
  vocab_size: int = 30522
  hidden_size: int = 768
  num_layers: int = 12
  num_attention_heads: int = 12
  hidden_activation: str = "gelu"
  intermediate_size: int = 3072
  dropout_rate: float = 0.1
  attention_dropout_rate: float = 0.1
  max_position_embeddings: int = 512
  type_vocab_size: int = 2
  initializer_range: float = 0.02
  embedding_size: Optional[int] = None
  output_range: Optional[int] = None
  return_all_encoder_outputs: bool = False
  return_attention_scores: bool = False
  # Pre/Post-LN Transformer
  norm_first: bool = False


class FNetEncoderConfig(hyperparams.Config):
  """FNet encoder configuration."""
  vocab_size: int = 30522
  hidden_size: int = 768
  num_layers: int = 12
  num_attention_heads: int = 12
  inner_activation: str = "gelu"
  inner_dim: int = 3072
  output_dropout: float = 0.1
  attention_dropout: float = 0.1
  max_sequence_length: int = 512
  type_vocab_size: int = 2
  initializer_range: float = 0.02
  embedding_width: Optional[int] = None
  output_range: Optional[int] = None
  norm_first: bool = False
  use_fft: bool = False
  attention_layers: Sequence[int] = ()


class SparseMixerEncoderConfig(hyperparams.Config):
  """SparseMixer encoder configuration."""
  vocab_size: int = 30522
  hidden_size: int = 768
  num_layers: int = 14
  moe_layers: Sequence[int] = (5, 6, 7, 8)
  attention_layers: Sequence[int] = (10, 11, 12, 13)
  num_experts: int = 16
  train_capacity_factor: float = 1.
  eval_capacity_factor: float = 1.
  examples_per_group: float = 1.
  use_fft: bool = False
  num_attention_heads: int = 8
  max_sequence_length: int = 512
  type_vocab_size: int = 2
  inner_dim: int = 3072
  inner_activation: str = "gelu"
  output_dropout: float = 0.1
  attention_dropout: float = 0.1
  initializer_range: float = 0.02
  output_range: Optional[int] = None
  embedding_width: Optional[int] = None
  norm_first: bool = False


class EncoderConfig(hyperparams.OneOfConfig):
  """Encoder configuration."""
  type: Optional[str] = "bert"
  albert: AlbertEncoderConfig = dataclasses.field(
      default_factory=AlbertEncoderConfig
  )
  bert: BertEncoderConfig = dataclasses.field(default_factory=BertEncoderConfig)
  bert_v2: BertEncoderConfig = dataclasses.field(
      default_factory=BertEncoderConfig
  )
  bigbird: BigBirdEncoderConfig = dataclasses.field(
      default_factory=BigBirdEncoderConfig
  )
  funnel: FunnelEncoderConfig = dataclasses.field(
      default_factory=FunnelEncoderConfig
  )
  kernel: KernelEncoderConfig = dataclasses.field(
      default_factory=KernelEncoderConfig
  )
  mobilebert: MobileBertEncoderConfig = dataclasses.field(
      default_factory=MobileBertEncoderConfig
  )
  reuse: ReuseEncoderConfig = dataclasses.field(
      default_factory=ReuseEncoderConfig
  )
  xlnet: XLNetEncoderConfig = dataclasses.field(
      default_factory=XLNetEncoderConfig
  )
  query_bert: QueryBertConfig = dataclasses.field(
      default_factory=QueryBertConfig
  )
  fnet: FNetEncoderConfig = dataclasses.field(default_factory=FNetEncoderConfig)
  sparse_mixer: SparseMixerEncoderConfig = dataclasses.field(
      default_factory=SparseMixerEncoderConfig
  )
  # If `any` is used, the encoder building relies on any.BUILDER.
  any: hyperparams.Config = dataclasses.field(
      default_factory=hyperparams.Config
  )
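
# `EncoderConfig` is a one-of config: `type` names the active branch and
# `get()` returns that branch's sub-config (the same accessors `build_encoder`
# below uses). A minimal sketch:
#
#   config = EncoderConfig(type="albert")
#   assert isinstance(config.get(), AlbertEncoderConfig)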


@gin.configurable
def build_encoder(config: EncoderConfig,
                  embedding_layer: Optional[tf_keras.layers.Layer] = None,
                  encoder_cls=None,
                  bypass_config: bool = False):
  """Instantiate a Transformer encoder network from EncoderConfig.

  Args:
    config: the one-of encoder config, which provides encoder parameters of a
      chosen encoder.
    embedding_layer: an external embedding layer passed to the encoder.
    encoder_cls: an external encoder class not included in the supported
      encoders, usually used by gin.configurable.
    bypass_config: whether to ignore the config instance and create the object
      directly with `encoder_cls`.

  Returns:
    An encoder instance.
  """
  if bypass_config:
    return encoder_cls()
  encoder_type = config.type
  encoder_cfg = config.get()
  if encoder_cls and encoder_cls.__name__ == "EncoderScaffold":
    embedding_cfg = dict(
        vocab_size=encoder_cfg.vocab_size,
        type_vocab_size=encoder_cfg.type_vocab_size,
        hidden_size=encoder_cfg.hidden_size,
        max_seq_length=encoder_cfg.max_position_embeddings,
        initializer=tf_keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        dropout_rate=encoder_cfg.dropout_rate,
    )
    hidden_cfg = dict(
        num_attention_heads=encoder_cfg.num_attention_heads,
        intermediate_size=encoder_cfg.intermediate_size,
        intermediate_activation=tf_utils.get_activation(
            encoder_cfg.hidden_activation),
        dropout_rate=encoder_cfg.dropout_rate,
        attention_dropout_rate=encoder_cfg.attention_dropout_rate,
        kernel_initializer=tf_keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
    )
    kwargs = dict(
        embedding_cfg=embedding_cfg,
        hidden_cfg=hidden_cfg,
        num_hidden_instances=encoder_cfg.num_layers,
        pooled_output_dim=encoder_cfg.hidden_size,
        pooler_layer_initializer=tf_keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        return_all_layer_outputs=encoder_cfg.return_all_encoder_outputs,
        dict_outputs=True)
    return encoder_cls(**kwargs)
if encoder_type == "any": | |
encoder = encoder_cfg.BUILDER(encoder_cfg) | |
if not isinstance(encoder, | |
(tf.Module, tf_keras.Model, tf_keras.layers.Layer)): | |
raise ValueError("The BUILDER returns an unexpected instance. The " | |
"`build_encoder` should returns a tf.Module, " | |
"tf_keras.Model or tf_keras.layers.Layer. However, " | |
f"we get {encoder.__class__}") | |
return encoder | |
if encoder_type == "mobilebert": | |
return networks.MobileBERTEncoder( | |
word_vocab_size=encoder_cfg.word_vocab_size, | |
word_embed_size=encoder_cfg.word_embed_size, | |
type_vocab_size=encoder_cfg.type_vocab_size, | |
max_sequence_length=encoder_cfg.max_sequence_length, | |
num_blocks=encoder_cfg.num_blocks, | |
hidden_size=encoder_cfg.hidden_size, | |
num_attention_heads=encoder_cfg.num_attention_heads, | |
intermediate_size=encoder_cfg.intermediate_size, | |
intermediate_act_fn=encoder_cfg.hidden_activation, | |
hidden_dropout_prob=encoder_cfg.hidden_dropout_prob, | |
attention_probs_dropout_prob=encoder_cfg.attention_probs_dropout_prob, | |
intra_bottleneck_size=encoder_cfg.intra_bottleneck_size, | |
initializer_range=encoder_cfg.initializer_range, | |
use_bottleneck_attention=encoder_cfg.use_bottleneck_attention, | |
key_query_shared_bottleneck=encoder_cfg.key_query_shared_bottleneck, | |
num_feedforward_networks=encoder_cfg.num_feedforward_networks, | |
normalization_type=encoder_cfg.normalization_type, | |
classifier_activation=encoder_cfg.classifier_activation, | |
input_mask_dtype=encoder_cfg.input_mask_dtype) | |
if encoder_type == "albert": | |
return networks.AlbertEncoder( | |
vocab_size=encoder_cfg.vocab_size, | |
embedding_width=encoder_cfg.embedding_width, | |
hidden_size=encoder_cfg.hidden_size, | |
num_layers=encoder_cfg.num_layers, | |
num_attention_heads=encoder_cfg.num_attention_heads, | |
max_sequence_length=encoder_cfg.max_position_embeddings, | |
type_vocab_size=encoder_cfg.type_vocab_size, | |
intermediate_size=encoder_cfg.intermediate_size, | |
activation=tf_utils.get_activation(encoder_cfg.hidden_activation), | |
dropout_rate=encoder_cfg.dropout_rate, | |
attention_dropout_rate=encoder_cfg.attention_dropout_rate, | |
initializer=tf_keras.initializers.TruncatedNormal( | |
stddev=encoder_cfg.initializer_range), | |
dict_outputs=True) | |
if encoder_type == "bigbird": | |
# TODO(frederickliu): Support use_gradient_checkpointing and update | |
# experiments to use the EncoderScaffold only. | |
if encoder_cfg.use_gradient_checkpointing: | |
return bigbird_encoder.BigBirdEncoder( | |
vocab_size=encoder_cfg.vocab_size, | |
hidden_size=encoder_cfg.hidden_size, | |
num_layers=encoder_cfg.num_layers, | |
num_attention_heads=encoder_cfg.num_attention_heads, | |
intermediate_size=encoder_cfg.intermediate_size, | |
activation=tf_utils.get_activation(encoder_cfg.hidden_activation), | |
dropout_rate=encoder_cfg.dropout_rate, | |
attention_dropout_rate=encoder_cfg.attention_dropout_rate, | |
num_rand_blocks=encoder_cfg.num_rand_blocks, | |
block_size=encoder_cfg.block_size, | |
max_position_embeddings=encoder_cfg.max_position_embeddings, | |
type_vocab_size=encoder_cfg.type_vocab_size, | |
initializer=tf_keras.initializers.TruncatedNormal( | |
stddev=encoder_cfg.initializer_range), | |
embedding_width=encoder_cfg.embedding_width, | |
use_gradient_checkpointing=encoder_cfg.use_gradient_checkpointing) | |
embedding_cfg = dict( | |
vocab_size=encoder_cfg.vocab_size, | |
type_vocab_size=encoder_cfg.type_vocab_size, | |
hidden_size=encoder_cfg.hidden_size, | |
max_seq_length=encoder_cfg.max_position_embeddings, | |
initializer=tf_keras.initializers.TruncatedNormal( | |
stddev=encoder_cfg.initializer_range), | |
dropout_rate=encoder_cfg.dropout_rate) | |
attention_cfg = dict( | |
num_heads=encoder_cfg.num_attention_heads, | |
key_dim=int(encoder_cfg.hidden_size // encoder_cfg.num_attention_heads), | |
kernel_initializer=tf_keras.initializers.TruncatedNormal( | |
stddev=encoder_cfg.initializer_range), | |
max_rand_mask_length=encoder_cfg.max_position_embeddings, | |
num_rand_blocks=encoder_cfg.num_rand_blocks, | |
from_block_size=encoder_cfg.block_size, | |
to_block_size=encoder_cfg.block_size, | |
) | |
hidden_cfg = dict( | |
num_attention_heads=encoder_cfg.num_attention_heads, | |
intermediate_size=encoder_cfg.intermediate_size, | |
intermediate_activation=tf_utils.get_activation( | |
encoder_cfg.hidden_activation), | |
dropout_rate=encoder_cfg.dropout_rate, | |
attention_dropout_rate=encoder_cfg.attention_dropout_rate, | |
norm_first=encoder_cfg.norm_first, | |
kernel_initializer=tf_keras.initializers.TruncatedNormal( | |
stddev=encoder_cfg.initializer_range), | |
attention_cls=layers.BigBirdAttention, | |
attention_cfg=attention_cfg) | |
kwargs = dict( | |
embedding_cfg=embedding_cfg, | |
hidden_cls=layers.TransformerScaffold, | |
hidden_cfg=hidden_cfg, | |
num_hidden_instances=encoder_cfg.num_layers, | |
mask_cls=layers.BigBirdMasks, | |
mask_cfg=dict(block_size=encoder_cfg.block_size), | |
pooled_output_dim=encoder_cfg.hidden_size, | |
pooler_layer_initializer=tf_keras.initializers.TruncatedNormal( | |
stddev=encoder_cfg.initializer_range), | |
return_all_layer_outputs=False, | |
dict_outputs=True, | |
layer_idx_as_attention_seed=True) | |
return networks.EncoderScaffold(**kwargs) | |
if encoder_type == "funnel": | |
if encoder_cfg.hidden_activation == "gelu": | |
activation = tf_utils.get_activation( | |
encoder_cfg.hidden_activation, | |
approximate=encoder_cfg.approx_gelu) | |
else: | |
activation = tf_utils.get_activation(encoder_cfg.hidden_activation) | |
return networks.FunnelTransformerEncoder( | |
vocab_size=encoder_cfg.vocab_size, | |
hidden_size=encoder_cfg.hidden_size, | |
num_layers=encoder_cfg.num_layers, | |
num_attention_heads=encoder_cfg.num_attention_heads, | |
max_sequence_length=encoder_cfg.max_position_embeddings, | |
type_vocab_size=encoder_cfg.type_vocab_size, | |
inner_dim=encoder_cfg.inner_dim, | |
inner_activation=activation, | |
output_dropout=encoder_cfg.dropout_rate, | |
attention_dropout=encoder_cfg.attention_dropout_rate, | |
pool_type=encoder_cfg.pool_type, | |
pool_stride=encoder_cfg.pool_stride, | |
unpool_length=encoder_cfg.unpool_length, | |
initializer=tf_keras.initializers.TruncatedNormal( | |
stddev=encoder_cfg.initializer_range), | |
output_range=encoder_cfg.output_range, | |
embedding_width=encoder_cfg.embedding_width, | |
embedding_layer=embedding_layer, | |
norm_first=encoder_cfg.norm_first, | |
share_rezero=encoder_cfg.share_rezero, | |
append_dense_inputs=encoder_cfg.append_dense_inputs, | |
transformer_cls=encoder_cfg.transformer_cls, | |
) | |
if encoder_type == "kernel": | |
embedding_cfg = dict( | |
vocab_size=encoder_cfg.vocab_size, | |
type_vocab_size=encoder_cfg.type_vocab_size, | |
hidden_size=encoder_cfg.hidden_size, | |
max_seq_length=encoder_cfg.max_position_embeddings, | |
initializer=tf_keras.initializers.TruncatedNormal( | |
stddev=encoder_cfg.initializer_range), | |
dropout_rate=encoder_cfg.dropout_rate) | |
attention_cfg = dict( | |
num_heads=encoder_cfg.num_attention_heads, | |
key_dim=int(encoder_cfg.hidden_size // encoder_cfg.num_attention_heads), | |
kernel_initializer=tf_keras.initializers.TruncatedNormal( | |
stddev=encoder_cfg.initializer_range), | |
feature_transform=encoder_cfg.feature_transform, | |
num_random_features=encoder_cfg.num_random_features, | |
redraw=encoder_cfg.redraw, | |
is_short_seq=encoder_cfg.is_short_seq, | |
begin_kernel=encoder_cfg.begin_kernel, | |
scale=encoder_cfg.scale, | |
) | |
hidden_cfg = dict( | |
num_attention_heads=encoder_cfg.num_attention_heads, | |
intermediate_size=encoder_cfg.intermediate_size, | |
intermediate_activation=tf_utils.get_activation( | |
encoder_cfg.hidden_activation), | |
dropout_rate=encoder_cfg.dropout_rate, | |
attention_dropout_rate=encoder_cfg.attention_dropout_rate, | |
norm_first=encoder_cfg.norm_first, | |
kernel_initializer=tf_keras.initializers.TruncatedNormal( | |
stddev=encoder_cfg.initializer_range), | |
attention_cls=layers.KernelAttention, | |
attention_cfg=attention_cfg) | |
kwargs = dict( | |
embedding_cfg=embedding_cfg, | |
hidden_cls=layers.TransformerScaffold, | |
hidden_cfg=hidden_cfg, | |
num_hidden_instances=encoder_cfg.num_layers, | |
mask_cls=layers.KernelMask, | |
pooled_output_dim=encoder_cfg.hidden_size, | |
pooler_layer_initializer=tf_keras.initializers.TruncatedNormal( | |
stddev=encoder_cfg.initializer_range), | |
return_all_layer_outputs=False, | |
dict_outputs=True, | |
layer_idx_as_attention_seed=True) | |
return networks.EncoderScaffold(**kwargs) | |
if encoder_type == "xlnet": | |
return networks.XLNetBase( | |
vocab_size=encoder_cfg.vocab_size, | |
num_layers=encoder_cfg.num_layers, | |
hidden_size=encoder_cfg.hidden_size, | |
num_attention_heads=encoder_cfg.num_attention_heads, | |
head_size=encoder_cfg.head_size, | |
inner_size=encoder_cfg.inner_size, | |
dropout_rate=encoder_cfg.dropout_rate, | |
attention_dropout_rate=encoder_cfg.attention_dropout_rate, | |
attention_type=encoder_cfg.attention_type, | |
bi_data=encoder_cfg.bi_data, | |
two_stream=encoder_cfg.two_stream, | |
tie_attention_biases=encoder_cfg.tie_attention_biases, | |
memory_length=encoder_cfg.memory_length, | |
clamp_length=encoder_cfg.clamp_length, | |
reuse_length=encoder_cfg.reuse_length, | |
inner_activation=encoder_cfg.inner_activation, | |
use_cls_mask=encoder_cfg.use_cls_mask, | |
embedding_width=encoder_cfg.embedding_width, | |
initializer=tf_keras.initializers.RandomNormal( | |
stddev=encoder_cfg.initializer_range)) | |
if encoder_type == "reuse": | |
embedding_cfg = dict( | |
vocab_size=encoder_cfg.vocab_size, | |
type_vocab_size=encoder_cfg.type_vocab_size, | |
hidden_size=encoder_cfg.hidden_size, | |
max_seq_length=encoder_cfg.max_position_embeddings, | |
initializer=tf_keras.initializers.TruncatedNormal( | |
stddev=encoder_cfg.initializer_range), | |
dropout_rate=encoder_cfg.dropout_rate) | |
hidden_cfg = dict( | |
num_attention_heads=encoder_cfg.num_attention_heads, | |
inner_dim=encoder_cfg.intermediate_size, | |
inner_activation=tf_utils.get_activation( | |
encoder_cfg.hidden_activation), | |
output_dropout=encoder_cfg.dropout_rate, | |
attention_dropout=encoder_cfg.attention_dropout_rate, | |
norm_first=encoder_cfg.norm_first, | |
kernel_initializer=tf_keras.initializers.TruncatedNormal( | |
stddev=encoder_cfg.initializer_range), | |
reuse_attention=encoder_cfg.reuse_attention, | |
use_relative_pe=encoder_cfg.use_relative_pe, | |
pe_max_seq_length=encoder_cfg.pe_max_seq_length, | |
max_reuse_layer_idx=encoder_cfg.max_reuse_layer_idx) | |
kwargs = dict( | |
embedding_cfg=embedding_cfg, | |
hidden_cls=layers.ReuseTransformer, | |
hidden_cfg=hidden_cfg, | |
num_hidden_instances=encoder_cfg.num_layers, | |
pooled_output_dim=encoder_cfg.hidden_size, | |
pooler_layer_initializer=tf_keras.initializers.TruncatedNormal( | |
stddev=encoder_cfg.initializer_range), | |
return_all_layer_outputs=False, | |
dict_outputs=True, | |
feed_layer_idx=True, | |
recursive=True) | |
return networks.EncoderScaffold(**kwargs) | |
if encoder_type == "query_bert": | |
embedding_layer = layers.FactorizedEmbedding( | |
vocab_size=encoder_cfg.vocab_size, | |
embedding_width=encoder_cfg.embedding_size, | |
output_dim=encoder_cfg.hidden_size, | |
initializer=tf_keras.initializers.TruncatedNormal( | |
stddev=encoder_cfg.initializer_range), | |
name="word_embeddings") | |
return networks.BertEncoderV2( | |
vocab_size=encoder_cfg.vocab_size, | |
hidden_size=encoder_cfg.hidden_size, | |
num_layers=encoder_cfg.num_layers, | |
num_attention_heads=encoder_cfg.num_attention_heads, | |
intermediate_size=encoder_cfg.intermediate_size, | |
activation=tf_utils.get_activation(encoder_cfg.hidden_activation), | |
dropout_rate=encoder_cfg.dropout_rate, | |
attention_dropout_rate=encoder_cfg.attention_dropout_rate, | |
max_sequence_length=encoder_cfg.max_position_embeddings, | |
type_vocab_size=encoder_cfg.type_vocab_size, | |
initializer=tf_keras.initializers.TruncatedNormal( | |
stddev=encoder_cfg.initializer_range), | |
output_range=encoder_cfg.output_range, | |
embedding_layer=embedding_layer, | |
return_all_encoder_outputs=encoder_cfg.return_all_encoder_outputs, | |
return_attention_scores=encoder_cfg.return_attention_scores, | |
dict_outputs=True, | |
norm_first=encoder_cfg.norm_first) | |
if encoder_type == "fnet": | |
return networks.FNet( | |
vocab_size=encoder_cfg.vocab_size, | |
hidden_size=encoder_cfg.hidden_size, | |
num_layers=encoder_cfg.num_layers, | |
num_attention_heads=encoder_cfg.num_attention_heads, | |
inner_dim=encoder_cfg.inner_dim, | |
inner_activation=tf_utils.get_activation(encoder_cfg.inner_activation), | |
output_dropout=encoder_cfg.output_dropout, | |
attention_dropout=encoder_cfg.attention_dropout, | |
max_sequence_length=encoder_cfg.max_sequence_length, | |
type_vocab_size=encoder_cfg.type_vocab_size, | |
initializer=tf_keras.initializers.TruncatedNormal( | |
stddev=encoder_cfg.initializer_range), | |
output_range=encoder_cfg.output_range, | |
embedding_width=encoder_cfg.embedding_width, | |
embedding_layer=embedding_layer, | |
norm_first=encoder_cfg.norm_first, | |
use_fft=encoder_cfg.use_fft, | |
attention_layers=encoder_cfg.attention_layers) | |
if encoder_type == "sparse_mixer": | |
return networks.SparseMixer( | |
vocab_size=encoder_cfg.vocab_size, | |
hidden_size=encoder_cfg.hidden_size, | |
num_layers=encoder_cfg.num_layers, | |
moe_layers=encoder_cfg.moe_layers, | |
attention_layers=encoder_cfg.attention_layers, | |
num_experts=encoder_cfg.num_experts, | |
train_capacity_factor=encoder_cfg.train_capacity_factor, | |
eval_capacity_factor=encoder_cfg.eval_capacity_factor, | |
examples_per_group=encoder_cfg.examples_per_group, | |
use_fft=encoder_cfg.use_fft, | |
num_attention_heads=encoder_cfg.num_attention_heads, | |
max_sequence_length=encoder_cfg.max_sequence_length, | |
type_vocab_size=encoder_cfg.type_vocab_size, | |
inner_dim=encoder_cfg.inner_dim, | |
inner_activation=tf_utils.get_activation(encoder_cfg.inner_activation), | |
output_dropout=encoder_cfg.output_dropout, | |
attention_dropout=encoder_cfg.attention_dropout, | |
initializer=tf_keras.initializers.TruncatedNormal( | |
stddev=encoder_cfg.initializer_range), | |
output_range=encoder_cfg.output_range, | |
embedding_width=encoder_cfg.embedding_width, | |
norm_first=encoder_cfg.norm_first, | |
embedding_layer=embedding_layer) | |
  bert_encoder_cls = networks.BertEncoder
  if encoder_type == "bert_v2":
    bert_encoder_cls = networks.BertEncoderV2

  # Uses the default BERTEncoder configuration schema to create the encoder.
  # If it does not match, please add a switch branch by the encoder type.
  return bert_encoder_cls(
      vocab_size=encoder_cfg.vocab_size,
      hidden_size=encoder_cfg.hidden_size,
      num_layers=encoder_cfg.num_layers,
      num_attention_heads=encoder_cfg.num_attention_heads,
      intermediate_size=encoder_cfg.intermediate_size,
      activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
      dropout_rate=encoder_cfg.dropout_rate,
      attention_dropout_rate=encoder_cfg.attention_dropout_rate,
      max_sequence_length=encoder_cfg.max_position_embeddings,
      type_vocab_size=encoder_cfg.type_vocab_size,
      initializer=tf_keras.initializers.TruncatedNormal(
          stddev=encoder_cfg.initializer_range),
      output_range=encoder_cfg.output_range,
      embedding_width=encoder_cfg.embedding_size,
      embedding_layer=embedding_layer,
      return_all_encoder_outputs=encoder_cfg.return_all_encoder_outputs,
      return_attention_scores=encoder_cfg.return_attention_scores,
      dict_outputs=True,
      norm_first=encoder_cfg.norm_first)
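
# End-to-end sketch (the two-layer override, batch size and sequence length
# are illustrative assumptions; the input dict follows the BertEncoder
# convention, and `dict_outputs=True` yields "sequence_output" among others):
#
#   config = EncoderConfig(type="bert", bert=BertEncoderConfig(num_layers=2))
#   encoder = build_encoder(config)
#   outputs = encoder({
#       "input_word_ids": tf.ones((2, 128), dtype=tf.int32),
#       "input_mask": tf.ones((2, 128), dtype=tf.int32),
#       "input_type_ids": tf.zeros((2, 128), dtype=tf.int32),
#   })
#   sequence_output = outputs["sequence_output"]  # shape: (2, 128, 768)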