# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=g-doc-return-or-yield,line-too-long
"""WMT translation configurations.""" | |
from official.core import config_definitions as cfg | |
from official.core import exp_factory | |
from official.modeling import optimization | |
from official.nlp.data import wmt_dataloader | |
from official.nlp.tasks import translation | |
@exp_factory.register_config_factory('wmt_transformer/large')
def wmt_transformer_large() -> cfg.ExperimentConfig:
  """WMT Transformer Large.

  Builds the ExperimentConfig for an en->de WMT'14 translation experiment
  with a Transformer-Large encoder/decoder, and registers it under the
  experiment name 'wmt_transformer/large' so the train script can look it
  up via the experiment factory.

  Please refer to
  tensorflow_models/official/nlp/data/train_sentencepiece.py
  to generate sentencepiece_model
  and pass
  --params_override=task.sentencepiece_model_path='YOUR_PATH'
  to the train script.

  Returns:
    A `cfg.ExperimentConfig` for the WMT Transformer Large experiment.
  """
  hidden_size = 1024
  # Transformer LR convention: base rate scaled by 1/sqrt(hidden_size),
  # paired with the inverse-sqrt ('power' -0.5) decay schedule below.
  learning_rate = 2.0 * hidden_size**-0.5
  warmup_steps = 16000
  train_steps = 300000
  # Batch size is counted in tokens, not examples (static_batch=True below).
  token_batch_size = 24576
  # The same layer config is shared by encoder and decoder (Transformer-Large:
  # 16 heads, FFN width = 4 * hidden_size).
  encdecoder = translation.EncDecoder(
      num_attention_heads=16, intermediate_size=hidden_size * 4)
  config = cfg.ExperimentConfig(
      runtime=cfg.RuntimeConfig(enable_xla=True),
      task=translation.TranslationConfig(
          model=translation.ModelConfig(
              encoder=encdecoder,
              decoder=encdecoder,
              embedding_width=hidden_size,
              # Padded decode with a fixed max length keeps decode shapes
              # static (required for XLA).
              padded_decode=True,
              decode_max_length=100),
          train_data=wmt_dataloader.WMTDataConfig(
              tfds_name='wmt14_translate/de-en',
              tfds_split='train',
              src_lang='en',
              tgt_lang='de',
              is_training=True,
              global_batch_size=token_batch_size,
              static_batch=True,
              max_seq_length=64,
          ),
          validation_data=wmt_dataloader.WMTDataConfig(
              tfds_name='wmt14_translate/de-en',
              tfds_split='test',
              src_lang='en',
              tgt_lang='de',
              is_training=False,
              global_batch_size=32,
              static_batch=True,
              max_seq_length=100,
          ),
          # Must be overridden on the command line; enforced by the
          # restriction below.
          sentencepiece_model_path=None,
      ),
      trainer=cfg.TrainerConfig(
          train_steps=train_steps,
          # -1 means run validation over the full validation split.
          validation_steps=-1,
          steps_per_loop=1000,
          summary_interval=1000,
          checkpoint_interval=5000,
          validation_interval=5000,
          max_to_keep=1,
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'adam',
                  'adam': {
                      'beta_2': 0.997,
                      'epsilon': 1e-9,
                  },
              },
              'learning_rate': {
                  # Inverse-sqrt decay: lr(step) = initial_lr * step**-0.5.
                  'type': 'power',
                  'power': {
                      'initial_learning_rate': learning_rate,
                      'power': -0.5,
                  }
              },
              'warmup': {
                  'type': 'linear',
                  'linear': {
                      'warmup_steps': warmup_steps,
                      'warmup_learning_rate': 0.0
                  }
              }
          })),
      restrictions=[
          'task.train_data.is_training != None',
          'task.sentencepiece_model_path != None',
      ])
  return config