# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Ranking Model configuration definition."""
import dataclasses
from typing import List, Optional, Union

from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import hyperparams


@dataclasses.dataclass
class CallbacksConfig(hyperparams.Config):
"""Configuration for Callbacks.
Attributes:
enable_checkpoint_and_export: Whether or not to enable checkpoints as a
Callback. Defaults to True.
enable_backup_and_restore: Whether or not to add BackupAndRestore
callback. Defaults to True.
enable_tensorboard: Whether or not to enable TensorBoard as a Callback.
Defaults to True.
enable_time_history: Whether or not to enable TimeHistory Callbacks.
Defaults to True.
"""
enable_checkpoint_and_export: bool = True
enable_backup_and_restore: bool = False
enable_tensorboard: bool = True
enable_time_history: bool = True


@dataclasses.dataclass
class LearningRateConfig(hyperparams.Config):
"""Learning rate scheduler config."""
learning_rate: float = 1.25
warmup_steps: int = 8000
decay_steps: int = 30000
decay_start_steps: int = 70000
decay_exp: float = 2


@dataclasses.dataclass
class OptimizationConfig(hyperparams.Config):
"""Embedding and dense optimizer configs."""
lr_config: LearningRateConfig = dataclasses.field(
default_factory=LearningRateConfig
)
dense_sgd_config: LearningRateConfig = dataclasses.field(
default_factory=lambda: LearningRateConfig(warmup_steps=0)
)
embedding_optimizer: str = 'SGD'
dense_optimizer: str = 'Adam'
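

# A minimal sketch of how the `LearningRateConfig` fields above are commonly
# combined (an illustration of intent, not the schedule implementation used by
# the library): linear warmup over `warmup_steps`, a constant phase until
# `decay_start_steps`, then a polynomial decay with exponent `decay_exp`
# spread over `decay_steps`.
def _example_learning_rate(step: int, lr_config: LearningRateConfig) -> float:
  """Illustrative only; not the schedule used by the training library."""
  if lr_config.warmup_steps and step < lr_config.warmup_steps:
    # Linear warmup from 0 to the peak learning rate.
    return lr_config.learning_rate * step / lr_config.warmup_steps
  if step < lr_config.decay_start_steps:
    # Hold the peak learning rate until decay starts.
    return lr_config.learning_rate
  # Polynomial decay towards 0 over `decay_steps`.
  progress = min(step - lr_config.decay_start_steps, lr_config.decay_steps)
  remaining = (lr_config.decay_steps - progress) / lr_config.decay_steps
  return lr_config.learning_rate * remaining**lr_config.decay_exp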


@dataclasses.dataclass
class DataConfig(hyperparams.Config):
"""Dataset config for training and evaluation."""
input_path: str = ''
global_batch_size: int = 0
is_training: bool = True
dtype: str = 'float32'
shuffle_buffer_size: int = 10000
cycle_length: int = 10
sharding: bool = True
num_shards_per_host: int = 8


@dataclasses.dataclass
class ModelConfig(hyperparams.Config):
"""Configuration for training.
Attributes:
num_dense_features: Number of dense features.
vocab_sizes: Vocab sizes for each of the sparse features. The order agrees
with the order of the input data.
embedding_dim: An integer or a list of embedding table dimensions.
If it's an integer then all tables will have the same embedding dimension.
If it's a list then the length should match with `vocab_sizes`.
size_threshold: A threshold for table sizes below which a keras
embedding layer is used, and above which a TPU embedding layer is used.
If it's -1 then only keras embedding layer will be used for all tables,
if 0 only then only TPU embedding layer will be used.
bottom_mlp: The sizes of hidden layers for bottom MLP applied to dense
features.
top_mlp: The sizes of hidden layers for top MLP.
interaction: Interaction can be on of the following:
'dot', 'cross'.
"""
num_dense_features: int = 13
vocab_sizes: List[int] = dataclasses.field(default_factory=list)
embedding_dim: Union[int, List[int]] = 8
size_threshold: int = 50_000
bottom_mlp: List[int] = dataclasses.field(default_factory=list)
top_mlp: List[int] = dataclasses.field(default_factory=list)
interaction: str = 'dot'
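

# A minimal illustration of how `size_threshold` is meant to be read; the
# helper name below is hypothetical and not part of the library. Tables with a
# vocabulary below the threshold can be served by a plain Keras embedding
# layer, larger ones by a TPU embedding layer, and the special values -1 and 0
# force Keras-only and TPU-only placement respectively.
def _example_table_placement(model_config: ModelConfig) -> List[str]:
  """Illustrative only: labels the embedding placement of each sparse table."""
  placements = []
  for vocab_size in model_config.vocab_sizes:
    if model_config.size_threshold == -1:
      placements.append('keras')
    elif model_config.size_threshold == 0:
      placements.append('tpu')
    else:
      placements.append(
          'keras' if vocab_size < model_config.size_threshold else 'tpu')
  return placements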


@dataclasses.dataclass
class Loss(hyperparams.Config):
"""Configuration for Loss.
Attributes:
label_smoothing: Whether or not to apply label smoothing to the
Binary Crossentropy loss.
"""
label_smoothing: float = 0.0


@dataclasses.dataclass
class Task(hyperparams.Config):
"""The model config."""
init_checkpoint: str = ''
model: ModelConfig = dataclasses.field(default_factory=ModelConfig)
train_data: DataConfig = dataclasses.field(
default_factory=lambda: DataConfig(is_training=True)
)
validation_data: DataConfig = dataclasses.field(
default_factory=lambda: DataConfig(is_training=False)
)
loss: Loss = dataclasses.field(default_factory=Loss)
use_synthetic_data: bool = False


@dataclasses.dataclass
class TimeHistoryConfig(hyperparams.Config):
"""Configuration for the TimeHistory callback.
Attributes:
log_steps: Interval of steps between logging of batch level stats.
"""
log_steps: Optional[int] = None


@dataclasses.dataclass
class TrainerConfig(cfg.TrainerConfig):
"""Configuration for training.
Attributes:
train_steps: The number of steps used to train.
validation_steps: The number of steps used to eval.
validation_interval: The Number of training steps to run between
evaluations.
callbacks: An instance of CallbacksConfig.
use_orbit: Whether to use orbit library with custom training loop or
compile/fit API.
enable_metrics_in_training: Whether to enable metrics during training.
time_history: Config of TimeHistory callback.
optimizer_config: An `OptimizerConfig` instance for embedding optimizer.
Defaults to None.
pipeline_sparse_and_dense_exeuction: Whether to pipeline embedding and
dense execution. This is a performance optimization.
"""
train_steps: int = 0
  # A validation_steps of -1 means the entire validation dataset is evaluated.
validation_steps: int = -1
validation_interval: int = 70000
callbacks: CallbacksConfig = dataclasses.field(
default_factory=CallbacksConfig
)
use_orbit: bool = False
enable_metrics_in_training: bool = True
time_history: TimeHistoryConfig = dataclasses.field(
default_factory=lambda: TimeHistoryConfig(log_steps=5000)
)
optimizer_config: OptimizationConfig = dataclasses.field(
default_factory=OptimizationConfig
)
pipeline_sparse_and_dense_execution: bool = False


NUM_TRAIN_EXAMPLES = 4195197692
NUM_EVAL_EXAMPLES = 89137318
train_batch_size = 16384
eval_batch_size = 16384
steps_per_epoch = NUM_TRAIN_EXAMPLES // train_batch_size
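# Derived values with the defaults above: steps_per_epoch is
# 4195197692 // 16384, roughly 256,054 steps, and a full evaluation pass is
# 89137318 // 16384, roughly 5,440 steps.
# The 26 entries below are the vocabulary sizes of the Criteo categorical
# features; `num_dense_features` (13) covers the dense integer features.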
vocab_sizes = [
39884406, 39043, 17289, 7420, 20263, 3, 7120, 1543, 63, 38532951,
2953546, 403346, 10, 2208, 11938, 155, 4, 976, 14, 39979771, 25641295,
39664984, 585935, 12972, 108, 36
]


@dataclasses.dataclass
class Config(hyperparams.Config):
"""Configuration to train the RankingModel.
By default it configures DLRM model on criteo dataset.
Attributes:
runtime: A `RuntimeConfig` instance.
task: `Task` instance.
trainer: A `TrainerConfig` instance.
"""
runtime: cfg.RuntimeConfig = dataclasses.field(
default_factory=cfg.RuntimeConfig
)
task: Task = dataclasses.field(
default_factory=lambda: Task( # pylint: disable=g-long-lambda
model=ModelConfig(
embedding_dim=8,
vocab_sizes=vocab_sizes,
bottom_mlp=[64, 32, 8],
top_mlp=[64, 32, 1],
),
loss=Loss(label_smoothing=0.0),
train_data=DataConfig(
is_training=True, global_batch_size=train_batch_size
),
validation_data=DataConfig(
is_training=False, global_batch_size=eval_batch_size
),
)
)
trainer: TrainerConfig = dataclasses.field(
default_factory=lambda: TrainerConfig( # pylint: disable=g-long-lambda
train_steps=2 * steps_per_epoch,
validation_interval=steps_per_epoch,
validation_steps=NUM_EVAL_EXAMPLES // eval_batch_size,
enable_metrics_in_training=True,
optimizer_config=OptimizationConfig(),
)
)
restrictions: dataclasses.InitVar[Optional[List[str]]] = None


def default_config() -> Config:
return Config(
runtime=cfg.RuntimeConfig(),
task=Task(
model=ModelConfig(
embedding_dim=8,
vocab_sizes=vocab_sizes,
bottom_mlp=[64, 32, 4],
top_mlp=[64, 32, 1]),
loss=Loss(label_smoothing=0.0),
train_data=DataConfig(
global_batch_size=train_batch_size,
is_training=True,
sharding=True),
validation_data=DataConfig(
global_batch_size=eval_batch_size,
is_training=False,
sharding=False)),
trainer=TrainerConfig(
train_steps=2 * steps_per_epoch,
validation_interval=steps_per_epoch,
validation_steps=NUM_EVAL_EXAMPLES // eval_batch_size,
enable_metrics_in_training=True,
optimizer_config=OptimizationConfig()),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None',
])


@exp_factory.register_config_factory('dlrm_criteo')
def dlrm_criteo_tb_config() -> Config:
return Config(
runtime=cfg.RuntimeConfig(),
task=Task(
model=ModelConfig(
num_dense_features=13,
vocab_sizes=vocab_sizes,
bottom_mlp=[512, 256, 64],
embedding_dim=64,
top_mlp=[1024, 1024, 512, 256, 1],
interaction='dot'),
loss=Loss(label_smoothing=0.0),
train_data=DataConfig(
global_batch_size=train_batch_size,
is_training=True,
sharding=True),
validation_data=DataConfig(
global_batch_size=eval_batch_size,
is_training=False,
sharding=False)),
trainer=TrainerConfig(
train_steps=steps_per_epoch,
validation_interval=steps_per_epoch // 2,
validation_steps=NUM_EVAL_EXAMPLES // eval_batch_size,
enable_metrics_in_training=True,
optimizer_config=OptimizationConfig()),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None',
])


@exp_factory.register_config_factory('dcn_criteo')
def dcn_criteo_tb_config() -> Config:
return Config(
runtime=cfg.RuntimeConfig(),
task=Task(
model=ModelConfig(
num_dense_features=13,
vocab_sizes=vocab_sizes,
bottom_mlp=[512, 256, 64],
embedding_dim=64,
top_mlp=[1024, 1024, 512, 256, 1],
interaction='cross'),
loss=Loss(label_smoothing=0.0),
train_data=DataConfig(
global_batch_size=train_batch_size,
is_training=True,
sharding=True),
validation_data=DataConfig(
global_batch_size=eval_batch_size,
is_training=False,
sharding=False)),
trainer=TrainerConfig(
train_steps=steps_per_epoch,
validation_interval=steps_per_epoch // 2,
validation_steps=NUM_EVAL_EXAMPLES // eval_batch_size,
enable_metrics_in_training=True,
optimizer_config=OptimizationConfig()),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None',
])
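

# A minimal usage sketch. `exp_factory.get_exp_config` and `Config.override`
# are assumed to work here as they do elsewhere in the TF Model Garden; the
# override values are placeholders, not recommended settings.
if __name__ == '__main__':
  dlrm_config = exp_factory.get_exp_config('dlrm_criteo')
  # Tweak a couple of trainer fields on this copy of the config.
  dlrm_config.override(
      {'trainer': {'train_steps': 1000, 'validation_interval': 500}},
      is_strict=False)
  print(dlrm_config.task.model.embedding_dim)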