# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformer Encoders.
Includes configurations and factory methods.
"""
import dataclasses
from typing import Optional, Sequence, Union

import gin
import tensorflow as tf, tf_keras

from official.modeling import hyperparams
from official.modeling import tf_utils
from official.nlp.modeling import layers
from official.nlp.modeling import networks
from official.projects.bigbird import encoder as bigbird_encoder


@dataclasses.dataclass
class BertEncoderConfig(hyperparams.Config):
"""BERT encoder configuration."""
vocab_size: int = 30522
hidden_size: int = 768
num_layers: int = 12
num_attention_heads: int = 12
hidden_activation: str = "gelu"
intermediate_size: int = 3072
dropout_rate: float = 0.1
attention_dropout_rate: float = 0.1
max_position_embeddings: int = 512
type_vocab_size: int = 2
initializer_range: float = 0.02
embedding_size: Optional[int] = None
output_range: Optional[int] = None
return_all_encoder_outputs: bool = False
return_attention_scores: bool = False
# Pre/Post-LN Transformer
norm_first: bool = False
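

# Note (illustrative addition, not part of the original file): each
# *EncoderConfig in this module is a `hyperparams.Config` dataclass, so fields
# can be overridden at construction time and the whole config exported as a
# nested dict, e.g.:
#
#   cfg = BertEncoderConfig(num_layers=24, hidden_size=1024)
#   cfg_dict = cfg.as_dict()  # e.g. for logging or YAML export.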


@dataclasses.dataclass
class FunnelEncoderConfig(hyperparams.Config):
"""Funnel encoder configuration."""
vocab_size: int = 30522
hidden_size: int = 768
num_layers: int = 12
num_attention_heads: int = 12
max_position_embeddings: int = 512
type_vocab_size: int = 16
inner_dim: int = 3072
hidden_activation: str = "gelu"
approx_gelu: bool = True
dropout_rate: float = 0.1
attention_dropout_rate: float = 0.1
pool_type: str = "max"
pool_stride: Union[int, Sequence[Union[int, float]]] = 2
unpool_length: int = 0
initializer_range: float = 0.02
output_range: Optional[int] = None
embedding_width: Optional[int] = None
embedding_layer: Optional[tf_keras.layers.Layer] = None
norm_first: bool = False
share_rezero: bool = False
append_dense_inputs: bool = False
transformer_cls: str = "TransformerEncoderBlock"


@dataclasses.dataclass
class MobileBertEncoderConfig(hyperparams.Config):
"""MobileBERT encoder configuration.
Attributes:
word_vocab_size: number of words in the vocabulary.
word_embed_size: word embedding size.
type_vocab_size: number of word types.
max_sequence_length: maximum length of input sequence.
    num_blocks: number of transformer blocks in the encoder model.
hidden_size: the hidden size for the transformer block.
num_attention_heads: number of attention heads in the transformer block.
intermediate_size: the size of the "intermediate" (a.k.a., feed forward)
layer.
hidden_activation: the non-linear activation function to apply to the
output of the intermediate/feed-forward layer.
hidden_dropout_prob: dropout probability for the hidden layers.
attention_probs_dropout_prob: dropout probability of the attention
probabilities.
    intra_bottleneck_size: the size of the bottleneck.
initializer_range: The stddev of the truncated_normal_initializer for
initializing all weight matrices.
    use_bottleneck_attention: whether to use attention inputs from the
      bottleneck transformation. If true, `key_query_shared_bottleneck` is
      ignored.
    key_query_shared_bottleneck: whether to share the linear transformation
      for keys and queries.
num_feedforward_networks: number of stacked feed-forward networks.
    normalization_type: the type of normalization; only 'no_norm' and
      'layer_norm' are supported. 'no_norm' represents the element-wise linear
      transformation for the student model, as suggested by the original
      MobileBERT paper. 'layer_norm' is used for the teacher model.
    classifier_activation: whether to use the tanh activation for the final
      representation of the [CLS] token in fine-tuning.
    input_mask_dtype: dtype of the input mask tensor, e.g. 'int32'.
  """
word_vocab_size: int = 30522
word_embed_size: int = 128
type_vocab_size: int = 2
max_sequence_length: int = 512
num_blocks: int = 24
hidden_size: int = 512
num_attention_heads: int = 4
intermediate_size: int = 4096
hidden_activation: str = "gelu"
hidden_dropout_prob: float = 0.1
attention_probs_dropout_prob: float = 0.1
intra_bottleneck_size: int = 1024
initializer_range: float = 0.02
use_bottleneck_attention: bool = False
key_query_shared_bottleneck: bool = False
num_feedforward_networks: int = 1
normalization_type: str = "layer_norm"
classifier_activation: bool = True
input_mask_dtype: str = "int32"


@dataclasses.dataclass
class AlbertEncoderConfig(hyperparams.Config):
"""ALBERT encoder configuration."""
vocab_size: int = 30000
embedding_width: int = 128
hidden_size: int = 768
num_layers: int = 12
num_attention_heads: int = 12
hidden_activation: str = "gelu"
intermediate_size: int = 3072
dropout_rate: float = 0.0
attention_dropout_rate: float = 0.0
max_position_embeddings: int = 512
type_vocab_size: int = 2
initializer_range: float = 0.02


@dataclasses.dataclass
class BigBirdEncoderConfig(hyperparams.Config):
"""BigBird encoder configuration."""
vocab_size: int = 50358
hidden_size: int = 768
num_layers: int = 12
num_attention_heads: int = 12
hidden_activation: str = "gelu"
intermediate_size: int = 3072
dropout_rate: float = 0.1
attention_dropout_rate: float = 0.1
# Pre/Post-LN Transformer
norm_first: bool = False
max_position_embeddings: int = 4096
num_rand_blocks: int = 3
block_size: int = 64
type_vocab_size: int = 16
initializer_range: float = 0.02
embedding_width: Optional[int] = None
use_gradient_checkpointing: bool = False


@dataclasses.dataclass
class KernelEncoderConfig(hyperparams.Config):
"""Linear encoder configuration."""
vocab_size: int = 30522
hidden_size: int = 768
num_layers: int = 12
num_attention_heads: int = 12
hidden_activation: str = "gelu"
intermediate_size: int = 3072
dropout_rate: float = 0.1
attention_dropout_rate: float = 0.1
# Pre/Post-LN Transformer
norm_first: bool = False
max_position_embeddings: int = 512
type_vocab_size: int = 2
initializer_range: float = 0.02
embedding_size: Optional[int] = None
feature_transform: str = "exp"
num_random_features: int = 256
redraw: bool = False
is_short_seq: bool = False
begin_kernel: int = 0
scale: Optional[float] = None


@dataclasses.dataclass
class ReuseEncoderConfig(hyperparams.Config):
"""Reuse encoder configuration."""
vocab_size: int = 30522
hidden_size: int = 768
num_layers: int = 12
num_attention_heads: int = 12
hidden_activation: str = "gelu"
intermediate_size: int = 3072
dropout_rate: float = 0.1
attention_dropout_rate: float = 0.1
max_position_embeddings: int = 512
type_vocab_size: int = 2
initializer_range: float = 0.02
embedding_size: Optional[int] = None
output_range: Optional[int] = None
return_all_encoder_outputs: bool = False
# Pre/Post-LN Transformer
norm_first: bool = False
# Reuse transformer
reuse_attention: int = -1
use_relative_pe: bool = False
pe_max_seq_length: int = 512
max_reuse_layer_idx: int = 6


@dataclasses.dataclass
class XLNetEncoderConfig(hyperparams.Config):
"""XLNet encoder configuration."""
vocab_size: int = 32000
num_layers: int = 24
hidden_size: int = 1024
num_attention_heads: int = 16
head_size: int = 64
inner_size: int = 4096
inner_activation: str = "gelu"
dropout_rate: float = 0.1
attention_dropout_rate: float = 0.1
attention_type: str = "bi"
bi_data: bool = False
tie_attention_biases: bool = False
memory_length: int = 0
same_length: bool = False
clamp_length: int = -1
reuse_length: int = 0
use_cls_mask: bool = False
embedding_width: int = 1024
initializer_range: float = 0.02
two_stream: bool = False


@dataclasses.dataclass
class QueryBertConfig(hyperparams.Config):
"""Query BERT encoder configuration."""
vocab_size: int = 30522
hidden_size: int = 768
num_layers: int = 12
num_attention_heads: int = 12
hidden_activation: str = "gelu"
intermediate_size: int = 3072
dropout_rate: float = 0.1
attention_dropout_rate: float = 0.1
max_position_embeddings: int = 512
type_vocab_size: int = 2
initializer_range: float = 0.02
embedding_size: Optional[int] = None
output_range: Optional[int] = None
return_all_encoder_outputs: bool = False
return_attention_scores: bool = False
# Pre/Post-LN Transformer
norm_first: bool = False


@dataclasses.dataclass
class FNetEncoderConfig(hyperparams.Config):
"""FNet encoder configuration."""
vocab_size: int = 30522
hidden_size: int = 768
num_layers: int = 12
num_attention_heads: int = 12
inner_activation: str = "gelu"
inner_dim: int = 3072
output_dropout: float = 0.1
attention_dropout: float = 0.1
max_sequence_length: int = 512
type_vocab_size: int = 2
initializer_range: float = 0.02
embedding_width: Optional[int] = None
output_range: Optional[int] = None
norm_first: bool = False
use_fft: bool = False
attention_layers: Sequence[int] = ()


@dataclasses.dataclass
class SparseMixerEncoderConfig(hyperparams.Config):
"""SparseMixer encoder configuration."""
vocab_size: int = 30522
hidden_size: int = 768
num_layers: int = 14
moe_layers: Sequence[int] = (5, 6, 7, 8)
attention_layers: Sequence[int] = (10, 11, 12, 13)
num_experts: int = 16
train_capacity_factor: float = 1.
eval_capacity_factor: float = 1.
examples_per_group: float = 1.
use_fft: bool = False
num_attention_heads: int = 8
max_sequence_length: int = 512
type_vocab_size: int = 2
inner_dim: int = 3072
inner_activation: str = "gelu"
output_dropout: float = 0.1
attention_dropout: float = 0.1
initializer_range: float = 0.02
output_range: Optional[int] = None
embedding_width: Optional[int] = None
norm_first: bool = False


@dataclasses.dataclass
class EncoderConfig(hyperparams.OneOfConfig):
"""Encoder configuration."""
type: Optional[str] = "bert"
albert: AlbertEncoderConfig = dataclasses.field(
default_factory=AlbertEncoderConfig
)
bert: BertEncoderConfig = dataclasses.field(default_factory=BertEncoderConfig)
bert_v2: BertEncoderConfig = dataclasses.field(
default_factory=BertEncoderConfig
)
bigbird: BigBirdEncoderConfig = dataclasses.field(
default_factory=BigBirdEncoderConfig
)
funnel: FunnelEncoderConfig = dataclasses.field(
default_factory=FunnelEncoderConfig
)
kernel: KernelEncoderConfig = dataclasses.field(
default_factory=KernelEncoderConfig
)
mobilebert: MobileBertEncoderConfig = dataclasses.field(
default_factory=MobileBertEncoderConfig
)
reuse: ReuseEncoderConfig = dataclasses.field(
default_factory=ReuseEncoderConfig
)
xlnet: XLNetEncoderConfig = dataclasses.field(
default_factory=XLNetEncoderConfig
)
query_bert: QueryBertConfig = dataclasses.field(
default_factory=QueryBertConfig
)
fnet: FNetEncoderConfig = dataclasses.field(default_factory=FNetEncoderConfig)
sparse_mixer: SparseMixerEncoderConfig = dataclasses.field(
default_factory=SparseMixerEncoderConfig
)
# If `any` is used, the encoder building relies on any.BUILDER.
any: hyperparams.Config = dataclasses.field(
default_factory=hyperparams.Config
)
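

# Illustrative sketch (an addition for clarity, not original code): the one-of
# `EncoderConfig` exposes the selected branch through its `type` field and
# `get()`, which is exactly how `build_encoder` below consumes it. For
# `type="any"`, the chosen sub-config must additionally provide a `BUILDER`
# callable.
#
#   config = EncoderConfig(
#       type="albert", albert=AlbertEncoderConfig(num_layers=6))
#   assert isinstance(config.get(), AlbertEncoderConfig)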


@gin.configurable
def build_encoder(config: EncoderConfig,
embedding_layer: Optional[tf_keras.layers.Layer] = None,
encoder_cls=None,
bypass_config: bool = False):
"""Instantiate a Transformer encoder network from EncoderConfig.
Args:
config: the one-of encoder config, which provides encoder parameters of a
chosen encoder.
embedding_layer: an external embedding layer passed to the encoder.
encoder_cls: an external encoder cls not included in the supported encoders,
usually used by gin.configurable.
bypass_config: whether to ignore config instance to create the object with
`encoder_cls`.
Returns:
An encoder instance.
"""
if bypass_config:
return encoder_cls()
encoder_type = config.type
encoder_cfg = config.get()
if encoder_cls and encoder_cls.__name__ == "EncoderScaffold":
embedding_cfg = dict(
vocab_size=encoder_cfg.vocab_size,
type_vocab_size=encoder_cfg.type_vocab_size,
hidden_size=encoder_cfg.hidden_size,
max_seq_length=encoder_cfg.max_position_embeddings,
initializer=tf_keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
dropout_rate=encoder_cfg.dropout_rate,
)
hidden_cfg = dict(
num_attention_heads=encoder_cfg.num_attention_heads,
intermediate_size=encoder_cfg.intermediate_size,
intermediate_activation=tf_utils.get_activation(
encoder_cfg.hidden_activation),
dropout_rate=encoder_cfg.dropout_rate,
attention_dropout_rate=encoder_cfg.attention_dropout_rate,
kernel_initializer=tf_keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
)
kwargs = dict(
embedding_cfg=embedding_cfg,
hidden_cfg=hidden_cfg,
num_hidden_instances=encoder_cfg.num_layers,
pooled_output_dim=encoder_cfg.hidden_size,
pooler_layer_initializer=tf_keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
return_all_layer_outputs=encoder_cfg.return_all_encoder_outputs,
dict_outputs=True)
return encoder_cls(**kwargs)
if encoder_type == "any":
encoder = encoder_cfg.BUILDER(encoder_cfg)
if not isinstance(encoder,
(tf.Module, tf_keras.Model, tf_keras.layers.Layer)):
raise ValueError("The BUILDER returns an unexpected instance. The "
"`build_encoder` should returns a tf.Module, "
"tf_keras.Model or tf_keras.layers.Layer. However, "
f"we get {encoder.__class__}")
return encoder
if encoder_type == "mobilebert":
return networks.MobileBERTEncoder(
word_vocab_size=encoder_cfg.word_vocab_size,
word_embed_size=encoder_cfg.word_embed_size,
type_vocab_size=encoder_cfg.type_vocab_size,
max_sequence_length=encoder_cfg.max_sequence_length,
num_blocks=encoder_cfg.num_blocks,
hidden_size=encoder_cfg.hidden_size,
num_attention_heads=encoder_cfg.num_attention_heads,
intermediate_size=encoder_cfg.intermediate_size,
intermediate_act_fn=encoder_cfg.hidden_activation,
hidden_dropout_prob=encoder_cfg.hidden_dropout_prob,
attention_probs_dropout_prob=encoder_cfg.attention_probs_dropout_prob,
intra_bottleneck_size=encoder_cfg.intra_bottleneck_size,
initializer_range=encoder_cfg.initializer_range,
use_bottleneck_attention=encoder_cfg.use_bottleneck_attention,
key_query_shared_bottleneck=encoder_cfg.key_query_shared_bottleneck,
num_feedforward_networks=encoder_cfg.num_feedforward_networks,
normalization_type=encoder_cfg.normalization_type,
classifier_activation=encoder_cfg.classifier_activation,
input_mask_dtype=encoder_cfg.input_mask_dtype)
if encoder_type == "albert":
return networks.AlbertEncoder(
vocab_size=encoder_cfg.vocab_size,
embedding_width=encoder_cfg.embedding_width,
hidden_size=encoder_cfg.hidden_size,
num_layers=encoder_cfg.num_layers,
num_attention_heads=encoder_cfg.num_attention_heads,
max_sequence_length=encoder_cfg.max_position_embeddings,
type_vocab_size=encoder_cfg.type_vocab_size,
intermediate_size=encoder_cfg.intermediate_size,
activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
dropout_rate=encoder_cfg.dropout_rate,
attention_dropout_rate=encoder_cfg.attention_dropout_rate,
initializer=tf_keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
dict_outputs=True)
if encoder_type == "bigbird":
# TODO(frederickliu): Support use_gradient_checkpointing and update
# experiments to use the EncoderScaffold only.
if encoder_cfg.use_gradient_checkpointing:
return bigbird_encoder.BigBirdEncoder(
vocab_size=encoder_cfg.vocab_size,
hidden_size=encoder_cfg.hidden_size,
num_layers=encoder_cfg.num_layers,
num_attention_heads=encoder_cfg.num_attention_heads,
intermediate_size=encoder_cfg.intermediate_size,
activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
dropout_rate=encoder_cfg.dropout_rate,
attention_dropout_rate=encoder_cfg.attention_dropout_rate,
num_rand_blocks=encoder_cfg.num_rand_blocks,
block_size=encoder_cfg.block_size,
max_position_embeddings=encoder_cfg.max_position_embeddings,
type_vocab_size=encoder_cfg.type_vocab_size,
initializer=tf_keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
embedding_width=encoder_cfg.embedding_width,
use_gradient_checkpointing=encoder_cfg.use_gradient_checkpointing)
embedding_cfg = dict(
vocab_size=encoder_cfg.vocab_size,
type_vocab_size=encoder_cfg.type_vocab_size,
hidden_size=encoder_cfg.hidden_size,
max_seq_length=encoder_cfg.max_position_embeddings,
initializer=tf_keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
dropout_rate=encoder_cfg.dropout_rate)
attention_cfg = dict(
num_heads=encoder_cfg.num_attention_heads,
key_dim=int(encoder_cfg.hidden_size // encoder_cfg.num_attention_heads),
kernel_initializer=tf_keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
max_rand_mask_length=encoder_cfg.max_position_embeddings,
num_rand_blocks=encoder_cfg.num_rand_blocks,
from_block_size=encoder_cfg.block_size,
to_block_size=encoder_cfg.block_size,
)
hidden_cfg = dict(
num_attention_heads=encoder_cfg.num_attention_heads,
intermediate_size=encoder_cfg.intermediate_size,
intermediate_activation=tf_utils.get_activation(
encoder_cfg.hidden_activation),
dropout_rate=encoder_cfg.dropout_rate,
attention_dropout_rate=encoder_cfg.attention_dropout_rate,
norm_first=encoder_cfg.norm_first,
kernel_initializer=tf_keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
attention_cls=layers.BigBirdAttention,
attention_cfg=attention_cfg)
kwargs = dict(
embedding_cfg=embedding_cfg,
hidden_cls=layers.TransformerScaffold,
hidden_cfg=hidden_cfg,
num_hidden_instances=encoder_cfg.num_layers,
mask_cls=layers.BigBirdMasks,
mask_cfg=dict(block_size=encoder_cfg.block_size),
pooled_output_dim=encoder_cfg.hidden_size,
pooler_layer_initializer=tf_keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
return_all_layer_outputs=False,
dict_outputs=True,
layer_idx_as_attention_seed=True)
return networks.EncoderScaffold(**kwargs)
if encoder_type == "funnel":
if encoder_cfg.hidden_activation == "gelu":
activation = tf_utils.get_activation(
encoder_cfg.hidden_activation,
approximate=encoder_cfg.approx_gelu)
else:
activation = tf_utils.get_activation(encoder_cfg.hidden_activation)
return networks.FunnelTransformerEncoder(
vocab_size=encoder_cfg.vocab_size,
hidden_size=encoder_cfg.hidden_size,
num_layers=encoder_cfg.num_layers,
num_attention_heads=encoder_cfg.num_attention_heads,
max_sequence_length=encoder_cfg.max_position_embeddings,
type_vocab_size=encoder_cfg.type_vocab_size,
inner_dim=encoder_cfg.inner_dim,
inner_activation=activation,
output_dropout=encoder_cfg.dropout_rate,
attention_dropout=encoder_cfg.attention_dropout_rate,
pool_type=encoder_cfg.pool_type,
pool_stride=encoder_cfg.pool_stride,
unpool_length=encoder_cfg.unpool_length,
initializer=tf_keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
output_range=encoder_cfg.output_range,
embedding_width=encoder_cfg.embedding_width,
embedding_layer=embedding_layer,
norm_first=encoder_cfg.norm_first,
share_rezero=encoder_cfg.share_rezero,
append_dense_inputs=encoder_cfg.append_dense_inputs,
transformer_cls=encoder_cfg.transformer_cls,
)
if encoder_type == "kernel":
embedding_cfg = dict(
vocab_size=encoder_cfg.vocab_size,
type_vocab_size=encoder_cfg.type_vocab_size,
hidden_size=encoder_cfg.hidden_size,
max_seq_length=encoder_cfg.max_position_embeddings,
initializer=tf_keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
dropout_rate=encoder_cfg.dropout_rate)
attention_cfg = dict(
num_heads=encoder_cfg.num_attention_heads,
key_dim=int(encoder_cfg.hidden_size // encoder_cfg.num_attention_heads),
kernel_initializer=tf_keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
feature_transform=encoder_cfg.feature_transform,
num_random_features=encoder_cfg.num_random_features,
redraw=encoder_cfg.redraw,
is_short_seq=encoder_cfg.is_short_seq,
begin_kernel=encoder_cfg.begin_kernel,
scale=encoder_cfg.scale,
)
hidden_cfg = dict(
num_attention_heads=encoder_cfg.num_attention_heads,
intermediate_size=encoder_cfg.intermediate_size,
intermediate_activation=tf_utils.get_activation(
encoder_cfg.hidden_activation),
dropout_rate=encoder_cfg.dropout_rate,
attention_dropout_rate=encoder_cfg.attention_dropout_rate,
norm_first=encoder_cfg.norm_first,
kernel_initializer=tf_keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
attention_cls=layers.KernelAttention,
attention_cfg=attention_cfg)
kwargs = dict(
embedding_cfg=embedding_cfg,
hidden_cls=layers.TransformerScaffold,
hidden_cfg=hidden_cfg,
num_hidden_instances=encoder_cfg.num_layers,
mask_cls=layers.KernelMask,
pooled_output_dim=encoder_cfg.hidden_size,
pooler_layer_initializer=tf_keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
return_all_layer_outputs=False,
dict_outputs=True,
layer_idx_as_attention_seed=True)
return networks.EncoderScaffold(**kwargs)
if encoder_type == "xlnet":
return networks.XLNetBase(
vocab_size=encoder_cfg.vocab_size,
num_layers=encoder_cfg.num_layers,
hidden_size=encoder_cfg.hidden_size,
num_attention_heads=encoder_cfg.num_attention_heads,
head_size=encoder_cfg.head_size,
inner_size=encoder_cfg.inner_size,
dropout_rate=encoder_cfg.dropout_rate,
attention_dropout_rate=encoder_cfg.attention_dropout_rate,
attention_type=encoder_cfg.attention_type,
bi_data=encoder_cfg.bi_data,
two_stream=encoder_cfg.two_stream,
tie_attention_biases=encoder_cfg.tie_attention_biases,
memory_length=encoder_cfg.memory_length,
clamp_length=encoder_cfg.clamp_length,
reuse_length=encoder_cfg.reuse_length,
inner_activation=encoder_cfg.inner_activation,
use_cls_mask=encoder_cfg.use_cls_mask,
embedding_width=encoder_cfg.embedding_width,
initializer=tf_keras.initializers.RandomNormal(
stddev=encoder_cfg.initializer_range))
if encoder_type == "reuse":
embedding_cfg = dict(
vocab_size=encoder_cfg.vocab_size,
type_vocab_size=encoder_cfg.type_vocab_size,
hidden_size=encoder_cfg.hidden_size,
max_seq_length=encoder_cfg.max_position_embeddings,
initializer=tf_keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
dropout_rate=encoder_cfg.dropout_rate)
hidden_cfg = dict(
num_attention_heads=encoder_cfg.num_attention_heads,
inner_dim=encoder_cfg.intermediate_size,
inner_activation=tf_utils.get_activation(
encoder_cfg.hidden_activation),
output_dropout=encoder_cfg.dropout_rate,
attention_dropout=encoder_cfg.attention_dropout_rate,
norm_first=encoder_cfg.norm_first,
kernel_initializer=tf_keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
reuse_attention=encoder_cfg.reuse_attention,
use_relative_pe=encoder_cfg.use_relative_pe,
pe_max_seq_length=encoder_cfg.pe_max_seq_length,
max_reuse_layer_idx=encoder_cfg.max_reuse_layer_idx)
kwargs = dict(
embedding_cfg=embedding_cfg,
hidden_cls=layers.ReuseTransformer,
hidden_cfg=hidden_cfg,
num_hidden_instances=encoder_cfg.num_layers,
pooled_output_dim=encoder_cfg.hidden_size,
pooler_layer_initializer=tf_keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
return_all_layer_outputs=False,
dict_outputs=True,
feed_layer_idx=True,
recursive=True)
return networks.EncoderScaffold(**kwargs)
if encoder_type == "query_bert":
embedding_layer = layers.FactorizedEmbedding(
vocab_size=encoder_cfg.vocab_size,
embedding_width=encoder_cfg.embedding_size,
output_dim=encoder_cfg.hidden_size,
initializer=tf_keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
name="word_embeddings")
return networks.BertEncoderV2(
vocab_size=encoder_cfg.vocab_size,
hidden_size=encoder_cfg.hidden_size,
num_layers=encoder_cfg.num_layers,
num_attention_heads=encoder_cfg.num_attention_heads,
intermediate_size=encoder_cfg.intermediate_size,
activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
dropout_rate=encoder_cfg.dropout_rate,
attention_dropout_rate=encoder_cfg.attention_dropout_rate,
max_sequence_length=encoder_cfg.max_position_embeddings,
type_vocab_size=encoder_cfg.type_vocab_size,
initializer=tf_keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
output_range=encoder_cfg.output_range,
embedding_layer=embedding_layer,
return_all_encoder_outputs=encoder_cfg.return_all_encoder_outputs,
return_attention_scores=encoder_cfg.return_attention_scores,
dict_outputs=True,
norm_first=encoder_cfg.norm_first)
if encoder_type == "fnet":
return networks.FNet(
vocab_size=encoder_cfg.vocab_size,
hidden_size=encoder_cfg.hidden_size,
num_layers=encoder_cfg.num_layers,
num_attention_heads=encoder_cfg.num_attention_heads,
inner_dim=encoder_cfg.inner_dim,
inner_activation=tf_utils.get_activation(encoder_cfg.inner_activation),
output_dropout=encoder_cfg.output_dropout,
attention_dropout=encoder_cfg.attention_dropout,
max_sequence_length=encoder_cfg.max_sequence_length,
type_vocab_size=encoder_cfg.type_vocab_size,
initializer=tf_keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
output_range=encoder_cfg.output_range,
embedding_width=encoder_cfg.embedding_width,
embedding_layer=embedding_layer,
norm_first=encoder_cfg.norm_first,
use_fft=encoder_cfg.use_fft,
attention_layers=encoder_cfg.attention_layers)
if encoder_type == "sparse_mixer":
return networks.SparseMixer(
vocab_size=encoder_cfg.vocab_size,
hidden_size=encoder_cfg.hidden_size,
num_layers=encoder_cfg.num_layers,
moe_layers=encoder_cfg.moe_layers,
attention_layers=encoder_cfg.attention_layers,
num_experts=encoder_cfg.num_experts,
train_capacity_factor=encoder_cfg.train_capacity_factor,
eval_capacity_factor=encoder_cfg.eval_capacity_factor,
examples_per_group=encoder_cfg.examples_per_group,
use_fft=encoder_cfg.use_fft,
num_attention_heads=encoder_cfg.num_attention_heads,
max_sequence_length=encoder_cfg.max_sequence_length,
type_vocab_size=encoder_cfg.type_vocab_size,
inner_dim=encoder_cfg.inner_dim,
inner_activation=tf_utils.get_activation(encoder_cfg.inner_activation),
output_dropout=encoder_cfg.output_dropout,
attention_dropout=encoder_cfg.attention_dropout,
initializer=tf_keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
output_range=encoder_cfg.output_range,
embedding_width=encoder_cfg.embedding_width,
norm_first=encoder_cfg.norm_first,
embedding_layer=embedding_layer)
bert_encoder_cls = networks.BertEncoder
if encoder_type == "bert_v2":
bert_encoder_cls = networks.BertEncoderV2
  # Uses the default BertEncoder configuration schema to create the encoder.
  # If the chosen config does not match this schema, add a branch for that
  # encoder type above.
return bert_encoder_cls(
vocab_size=encoder_cfg.vocab_size,
hidden_size=encoder_cfg.hidden_size,
num_layers=encoder_cfg.num_layers,
num_attention_heads=encoder_cfg.num_attention_heads,
intermediate_size=encoder_cfg.intermediate_size,
activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
dropout_rate=encoder_cfg.dropout_rate,
attention_dropout_rate=encoder_cfg.attention_dropout_rate,
max_sequence_length=encoder_cfg.max_position_embeddings,
type_vocab_size=encoder_cfg.type_vocab_size,
initializer=tf_keras.initializers.TruncatedNormal(
stddev=encoder_cfg.initializer_range),
output_range=encoder_cfg.output_range,
embedding_width=encoder_cfg.embedding_size,
embedding_layer=embedding_layer,
return_all_encoder_outputs=encoder_cfg.return_all_encoder_outputs,
return_attention_scores=encoder_cfg.return_attention_scores,
dict_outputs=True,
norm_first=encoder_cfg.norm_first)
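

# Minimal usage sketch (added for illustration; not part of the original
# module): builds the default BERT encoder from an `EncoderConfig` and runs a
# dummy batch through it. The input names follow the dict interface of
# `networks.BertEncoder`; the batch and sequence sizes are arbitrary.
if __name__ == "__main__":
  _config = EncoderConfig(type="bert")
  _encoder = build_encoder(_config)
  _dummy_inputs = dict(
      input_word_ids=tf.ones((2, 16), dtype=tf.int32),
      input_mask=tf.ones((2, 16), dtype=tf.int32),
      input_type_ids=tf.zeros((2, 16), dtype=tf.int32))
  _outputs = _encoder(_dummy_inputs)
  # With the defaults, `sequence_output` has shape (2, 16, 768).
  print(_outputs["sequence_output"].shape)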