# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=g-doc-return-or-yield,line-too-long
"""WMT translation configurations."""
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import optimization
from official.nlp.data import wmt_dataloader
from official.nlp.tasks import translation


@exp_factory.register_config_factory('wmt_transformer/large')
def wmt_transformer_large() -> cfg.ExperimentConfig:
  """WMT Transformer Large.

  Please refer to
  tensorflow_models/official/nlp/data/train_sentencepiece.py
  to generate the sentencepiece model, and pass
  --params_override=task.sentencepiece_model_path='YOUR_PATH'
  to the training script.
  """
  learning_rate = 2.0
  hidden_size = 1024
  learning_rate *= (hidden_size**-0.5)
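  # With hidden_size = 1024, this yields a peak rate of 2.0 / sqrt(1024) =
  # 0.0625, i.e. the model-dimension scaling used by the original Transformer
  # ("Attention Is All You Need") learning-rate schedule.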
  warmup_steps = 16000
  train_steps = 300000
  token_batch_size = 24576
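  # Note: the training batch size is specified in tokens rather than
  # sentences; 24576 tokens with static batching at max_seq_length=64 comes
  # to 24576 / 64 = 384 fixed-shape sequences per global batch.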
  encdecoder = translation.EncDecoder(
      num_attention_heads=16, intermediate_size=hidden_size * 4)
  config = cfg.ExperimentConfig(
      runtime=cfg.RuntimeConfig(enable_xla=True),
      task=translation.TranslationConfig(
          model=translation.ModelConfig(
              encoder=encdecoder,
              decoder=encdecoder,
              embedding_width=hidden_size,
              padded_decode=True,
              decode_max_length=100),
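          # padded_decode=True keeps decoding shapes static, which XLA/TPU
          # execution requires; generated sequences are capped at
          # decode_max_length=100 tokens.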
          train_data=wmt_dataloader.WMTDataConfig(
              tfds_name='wmt14_translate/de-en',
              tfds_split='train',
              src_lang='en',
              tgt_lang='de',
              is_training=True,
              global_batch_size=token_batch_size,
              static_batch=True,
              max_seq_length=64,
          ),
          validation_data=wmt_dataloader.WMTDataConfig(
              tfds_name='wmt14_translate/de-en',
              tfds_split='test',
              src_lang='en',
              tgt_lang='de',
              is_training=False,
              global_batch_size=32,
              static_batch=True,
              max_seq_length=100,
          ),
          sentencepiece_model_path=None,
      ),
      trainer=cfg.TrainerConfig(
          train_steps=train_steps,
          validation_steps=-1,
          steps_per_loop=1000,
          summary_interval=1000,
          checkpoint_interval=5000,
          validation_interval=5000,
          max_to_keep=1,
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'adam',
                  'adam': {
                      'beta_2': 0.997,
                      'epsilon': 1e-9,
                  },
              },
              'learning_rate': {
                  'type': 'power',
                  'power': {
                      'initial_learning_rate': learning_rate,
                      'power': -0.5,
                  }
              },
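              # After warmup, lr(step) = initial_learning_rate * step**-0.5;
              # combined with the linear warmup below, this closely follows
              # the inverse-square-root ("Noam") schedule from the
              # Transformer paper.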
              'warmup': {
                  'type': 'linear',
                  'linear': {
                      'warmup_steps': warmup_steps,
                      'warmup_learning_rate': 0.0
                  }
              }
          })),
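      # These restrictions are checked when the config is validated, so a run
      # fails fast unless sentencepiece_model_path is supplied (e.g. via
      # --params_override).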
      restrictions=[
          'task.train_data.is_training != None',
          'task.sentencepiece_model_path != None',
      ])
  return config
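

# A minimal usage sketch (hypothetical model path; assumes the Model Garden
# packages imported above are installed):
#
#   from official.core import exp_factory
#
#   config = exp_factory.get_exp_config('wmt_transformer/large')
#   config.task.sentencepiece_model_path = '/path/to/sentencepiece.model'
#   config.validate()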