|
from __gin__ import dynamic_registration |
|
import __main__ as train_script |
|
import seqio |
|
from t5x import adafactor |
|
from t5x.examples.t5 import network |
|
from t5x import gin_utils |
|
from t5x import models |
|
from t5x import partitioning |
|
from t5x import trainer |
|
from t5x import utils |
|
import tasks |
|
|
|
|
|
|
|
# Macros: |
# ============================================================================== |
# Shared values referenced as %NAME by the bindings further down in this file. |
# - BATCH_SIZE, MIXTURE_OR_TASK_MODULE, MIXTURE_OR_TASK_NAME, |
#   TASK_FEATURE_LENGTHS and USE_CACHED_TASKS feed both the train/ and |
#   train_eval/ utils.DatasetConfig scopes. |
# - SHUFFLE_TRAIN_EXAMPLES feeds only the train/ scope. |
# - DROPOUT_RATE feeds network.T5Config. |
# - LABEL_SMOOTHING, LOSS_NORMALIZING_FACTOR, OPTIMIZER, VOCABULARY and Z_LOSS |
#   feed models.EncoderDecoderModel. |
# - MODEL, MODEL_DIR, RANDOM_SEED, TRAIN_STEPS and USE_HARDWARE_RNG feed |
#   train_script.train. |
BATCH_SIZE = 256 |
|
DROPOUT_RATE = 0.0 |
|
LABEL_SMOOTHING = 0.0 |
|
LOSS_NORMALIZING_FACTOR = None |
|
MIXTURE_OR_TASK_MODULE = None |
|
MIXTURE_OR_TASK_NAME = 'pretrain_finnish' |
|
MODEL = @models.EncoderDecoderModel() |
|
# NOTE(review): absolute, machine-specific output path - confirm before reuse. |
MODEL_DIR = '/researchdisk/t5x-mini-nl8-finnish' |
|
OPTIMIZER = @adafactor.Adafactor() |
|
RANDOM_SEED = None |
|
SHUFFLE_TRAIN_EXAMPLES = True |
|
TASK_FEATURE_LENGTHS = {'inputs': 512, 'targets': 114} |
|
TRAIN_STEPS = 500000 |
|
USE_CACHED_TASKS = False |
|
USE_HARDWARE_RNG = False |
|
VOCABULARY = @seqio.SentencePieceVocabulary() |
|
Z_LOSS = 0.0001 |
|
|
|
# Parameters for adafactor.Adafactor: |
|
# ============================================================================== |
|
# Configures the optimizer instance that the OPTIMIZER macro refers to. |
adafactor.Adafactor.decay_rate = 0.8 |
|
# Kept on a single line: a trailing backslash followed by an empty line leaves |
# the binding without a value and strands the reference on its own line. |
adafactor.Adafactor.logical_factor_rules = @adafactor.standard_logical_factor_rules() |
|
adafactor.Adafactor.step_offset = 0 |
|
|
|
# Parameters for utils.CheckpointConfig: |
|
# ============================================================================== |
|
# Bundles the restore and save configs bound elsewhere in this file; the |
# resulting object is passed to train_script.train.checkpoint_cfg. |
utils.CheckpointConfig.restore = @utils.RestoreCheckpointConfig() |
|
utils.CheckpointConfig.save = @utils.SaveCheckpointConfig() |
|
|
|
# Parameters for utils.create_learning_rate_scheduler: |
|
# ============================================================================== |
|
# The scheduler built from these bindings is used as |
# trainer.Trainer.learning_rate_fn. |
# NOTE(review): the 'factors' string is interpreted by |
# utils.create_learning_rate_scheduler; presumably a constant base rate |
# multiplied by an inverse-square-root decay that uses warmup_steps - confirm |
# against the t5x.utils implementation. |
utils.create_learning_rate_scheduler.base_learning_rate = 1.0 |
|
utils.create_learning_rate_scheduler.factors = 'constant * rsqrt_decay' |
|
utils.create_learning_rate_scheduler.warmup_steps = 10000 |
|
|
|
# Parameters for train/utils.DatasetConfig: |
|
# ============================================================================== |
|
# Bindings in the train/ scope; the resulting config is passed to |
# train_script.train.train_dataset_cfg. Unlike the train_eval/ scope, the seed |
# is left as None and shuffling follows the SHUFFLE_TRAIN_EXAMPLES macro. |
train/utils.DatasetConfig.batch_size = %BATCH_SIZE |
|
train/utils.DatasetConfig.mixture_or_task_name = %MIXTURE_OR_TASK_NAME |
|
train/utils.DatasetConfig.module = %MIXTURE_OR_TASK_MODULE |
|
train/utils.DatasetConfig.pack = True |
|
train/utils.DatasetConfig.seed = None |
|
train/utils.DatasetConfig.shuffle = %SHUFFLE_TRAIN_EXAMPLES |
|
train/utils.DatasetConfig.split = 'train' |
|
train/utils.DatasetConfig.task_feature_lengths = %TASK_FEATURE_LENGTHS |
|
train/utils.DatasetConfig.use_cached = %USE_CACHED_TASKS |
|
|
|
# Parameters for train_eval/utils.DatasetConfig: |
|
# ============================================================================== |
|
# Bindings in the train_eval/ scope; the resulting config is passed to |
# train_script.train.train_eval_dataset_cfg. It shares batch size, task and |
# feature lengths with the train/ scope but reads the 'validation' split with |
# a fixed seed (42) and shuffling disabled. |
train_eval/utils.DatasetConfig.batch_size = %BATCH_SIZE |
|
train_eval/utils.DatasetConfig.mixture_or_task_name = %MIXTURE_OR_TASK_NAME |
|
train_eval/utils.DatasetConfig.module = %MIXTURE_OR_TASK_MODULE |
|
train_eval/utils.DatasetConfig.pack = True |
|
train_eval/utils.DatasetConfig.seed = 42 |
|
train_eval/utils.DatasetConfig.shuffle = False |
|
train_eval/utils.DatasetConfig.split = 'validation' |
|
train_eval/utils.DatasetConfig.task_feature_lengths = %TASK_FEATURE_LENGTHS |
|
train_eval/utils.DatasetConfig.use_cached = %USE_CACHED_TASKS |
|
|
|
# Parameters for models.EncoderDecoderModel: |
|
# ============================================================================== |
|
# Configures the model instance that the MODEL macro refers to. Input and |
# output share the same VOCABULARY; the network itself is network.Transformer, |
# and the loss-related values come from the macros defined earlier in the file. |
models.EncoderDecoderModel.input_vocabulary = %VOCABULARY |
|
models.EncoderDecoderModel.label_smoothing = %LABEL_SMOOTHING |
|
models.EncoderDecoderModel.loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR |
|
models.EncoderDecoderModel.module = @network.Transformer() |
|
models.EncoderDecoderModel.optimizer_def = %OPTIMIZER |
|
models.EncoderDecoderModel.output_vocabulary = %VOCABULARY |
|
models.EncoderDecoderModel.z_loss = %Z_LOSS |
|
|
|
# Parameters for partitioning.PjitPartitioner: |
|
# ============================================================================== |
|
# Configures the partitioner passed to train_script.train.partitioner. |
# Kept on a single line: a trailing backslash followed by an empty line leaves |
# the binding without a value and strands the reference on its own line. |
partitioning.PjitPartitioner.logical_axis_rules = @partitioning.standard_logical_axis_rules() |
|
partitioning.PjitPartitioner.model_parallel_submesh = None |
|
partitioning.PjitPartitioner.num_partitions = 1 |
|
|
|
# Parameters for utils.RestoreCheckpointConfig: |
|
# ============================================================================== |
|
# Used as utils.CheckpointConfig.restore. |
# NOTE(review): the empty path list presumably means no external checkpoint |
# is restored - confirm against t5x.utils.RestoreCheckpointConfig. |
utils.RestoreCheckpointConfig.path = [] |
|
|
|
# Parameters for utils.SaveCheckpointConfig: |
|
# ============================================================================== |
|
# Used as utils.CheckpointConfig.save. |
# NOTE(review): presumably 'period' is the number of steps between saves and |
# 'keep' the number of checkpoints retained - confirm against |
# t5x.utils.SaveCheckpointConfig. |
utils.SaveCheckpointConfig.dtype = 'float32' |
|
utils.SaveCheckpointConfig.keep = 5 |
|
utils.SaveCheckpointConfig.period = 20000 |
|
utils.SaveCheckpointConfig.save_dataset = False |
|
|
|
# Parameters for seqio.SentencePieceVocabulary: |
|
# ============================================================================== |
|
# Configures the vocabulary instance that the VOCABULARY macro refers to. |
# NOTE(review): the model file is given as a relative path; presumably it is |
# resolved against the working directory of the training run - confirm. |
seqio.SentencePieceVocabulary.sentencepiece_model_file = 'spiece.model' |
|
|
|
# Parameters for network.T5Config: |
|
# ============================================================================== |
|
# Architecture settings consumed through network.Transformer.config: 8 encoder |
# and 8 decoder layers, embedding width 384, MLP width 1536. |
# NOTE(review): num_heads * head_dim = 8 * 64 = 512, which differs from |
# emb_dim = 384; presumably intentional for this model size - confirm. |
network.T5Config.dropout_rate = %DROPOUT_RATE |
|
network.T5Config.dtype = 'bfloat16' |
|
network.T5Config.emb_dim = 384 |
|
network.T5Config.head_dim = 64 |
|
network.T5Config.logits_via_embedding = False |
|
network.T5Config.mlp_activations = ('gelu', 'linear') |
|
network.T5Config.mlp_dim = 1536 |
|
network.T5Config.num_decoder_layers = 8 |
|
network.T5Config.num_encoder_layers = 8 |
|
network.T5Config.num_heads = 8 |
|
network.T5Config.vocab_size = 32128 |
|
|
|
# Parameters for train_script.train: |
|
# ============================================================================== |
|
# Top-level wiring for the training entry point. train_script is an alias for |
# __main__ (see the imports), so these bindings apply to the train function of |
# whichever script is run as the main module. |
# No inference-eval dataset is configured (infer_eval_dataset_cfg = None). |
train_script.train.checkpoint_cfg = @utils.CheckpointConfig() |
|
train_script.train.eval_period = 10000 |
|
train_script.train.eval_steps = 20 |
|
train_script.train.infer_eval_dataset_cfg = None |
|
train_script.train.model = %MODEL |
|
train_script.train.model_dir = %MODEL_DIR |
|
train_script.train.partitioner = @partitioning.PjitPartitioner() |
|
train_script.train.random_seed = %RANDOM_SEED |
|
# Passed as a function reference (no trailing parentheses), not as a call result. |
train_script.train.summarize_config_fn = @gin_utils.summarize_gin_config |
|
train_script.train.total_steps = %TRAIN_STEPS |
|
train_script.train.train_dataset_cfg = @train/utils.DatasetConfig() |
|
train_script.train.train_eval_dataset_cfg = @train_eval/utils.DatasetConfig() |
|
# Passed as the class itself (no trailing parentheses), not as an instance. |
train_script.train.trainer_cls = @trainer.Trainer |
|
train_script.train.use_hardware_rng = %USE_HARDWARE_RNG |
|
|
|
# Parameters for trainer.Trainer: |
|
# ============================================================================== |
|
# Applies to the class handed to train_script.train.trainer_cls. The learning |
# rate function is the scheduler configured under |
# utils.create_learning_rate_scheduler in this file. |
trainer.Trainer.learning_rate_fn = @utils.create_learning_rate_scheduler() |
|
trainer.Trainer.num_microbatches = None |
|
|
|
# Parameters for network.Transformer: |
|
# ============================================================================== |
|
# The Transformer used as models.EncoderDecoderModel.module takes its |
# architecture from the network.T5Config bindings in this file. |
network.Transformer.config = @network.T5Config() |
|
|