# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Ranking Model configuration definition.""" | |
import dataclasses | |
from typing import List, Optional, Union | |
from official.core import config_definitions as cfg | |
from official.core import exp_factory | |
from official.modeling import hyperparams | |


@dataclasses.dataclass
class CallbacksConfig(hyperparams.Config):
  """Configuration for Callbacks.

  Attributes:
    enable_checkpoint_and_export: Whether or not to enable checkpoints as a
      Callback. Defaults to True.
    enable_backup_and_restore: Whether or not to add the BackupAndRestore
      callback. Defaults to False.
    enable_tensorboard: Whether or not to enable TensorBoard as a Callback.
      Defaults to True.
    enable_time_history: Whether or not to enable TimeHistory Callbacks.
      Defaults to True.
  """
  enable_checkpoint_and_export: bool = True
  enable_backup_and_restore: bool = False
  enable_tensorboard: bool = True
  enable_time_history: bool = True


@dataclasses.dataclass
class LearningRateConfig(hyperparams.Config):
  """Learning rate scheduler config."""
  learning_rate: float = 1.25
  warmup_steps: int = 8000
  decay_steps: int = 30000
  decay_start_steps: int = 70000
  decay_exp: float = 2
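

# Illustrative only, not part of the original module: a minimal sketch of how
# a trainer *might* consume the fields above, assuming linear warmup to
# `learning_rate`, a constant phase, and polynomial decay with exponent
# `decay_exp` that starts at `decay_start_steps` and lasts `decay_steps`.
def _example_learning_rate(step: int, config: LearningRateConfig) -> float:
  """Hypothetical schedule derived from LearningRateConfig fields."""
  if config.warmup_steps and step < config.warmup_steps:
    # Linear warmup from 0 to the peak learning rate.
    return config.learning_rate * step / config.warmup_steps
  if step < config.decay_start_steps:
    # Constant phase between warmup and decay.
    return config.learning_rate
  # Polynomial decay towards 0 over `decay_steps`.
  decay_progress = min(
      (step - config.decay_start_steps) / config.decay_steps, 1.0)
  return config.learning_rate * (1.0 - decay_progress) ** config.decay_exp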


@dataclasses.dataclass
class OptimizationConfig(hyperparams.Config):
  """Embedding and dense optimizer configs."""
  lr_config: LearningRateConfig = dataclasses.field(
      default_factory=LearningRateConfig
  )
  dense_sgd_config: LearningRateConfig = dataclasses.field(
      default_factory=lambda: LearningRateConfig(warmup_steps=0)
  )
  embedding_optimizer: str = 'SGD'
  dense_optimizer: str = 'Adam'
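

# Example for illustration only (not from the original file): fields on a
# `hyperparams.Config` subclass can be overridden at construction time,
# exactly as the lambdas in this module do. The optimizer name below is a
# placeholder, not a statement about which optimizers the task supports.
def _example_optimization_override() -> OptimizationConfig:
  return OptimizationConfig(
      embedding_optimizer='Adagrad',
      lr_config=LearningRateConfig(learning_rate=0.5, warmup_steps=0),
  )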


@dataclasses.dataclass
class DataConfig(hyperparams.Config):
  """Dataset config for training and evaluation."""
  input_path: str = ''
  global_batch_size: int = 0
  is_training: bool = True
  dtype: str = 'float32'
  shuffle_buffer_size: int = 10000
  cycle_length: int = 10
  sharding: bool = True
  num_shards_per_host: int = 8
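

# Illustrative only (an assumption about how the fields above are consumed;
# the real input pipeline lives elsewhere in the Model Garden): a minimal
# tf.data sketch wiring `cycle_length` and `shuffle_buffer_size`, assuming
# TFRecord inputs.
def _example_dataset(data: DataConfig):
  import tensorflow as tf  # local import, kept out of the module's real deps
  files = tf.data.Dataset.list_files(data.input_path, shuffle=data.is_training)
  dataset = files.interleave(
      tf.data.TFRecordDataset,
      cycle_length=data.cycle_length,
      num_parallel_calls=tf.data.AUTOTUNE)
  if data.is_training:
    dataset = dataset.shuffle(data.shuffle_buffer_size)
  # Parsing of dense and sparse features would happen here before batching.
  return dataset.batch(data.global_batch_size, drop_remainder=True)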


@dataclasses.dataclass
class ModelConfig(hyperparams.Config):
  """Configuration for training.

  Attributes:
    num_dense_features: Number of dense features.
    vocab_sizes: Vocab sizes for each of the sparse features. The order agrees
      with the order of the input data.
    embedding_dim: An integer or a list of embedding table dimensions.
      If it's an integer then all tables will have the same embedding
      dimension. If it's a list then its length must match the length of
      `vocab_sizes`.
    size_threshold: A threshold for table sizes below which a Keras
      embedding layer is used, and above which a TPU embedding layer is used.
      If it's -1 then only the Keras embedding layer will be used for all
      tables; if it's 0 then only the TPU embedding layer will be used.
    bottom_mlp: The sizes of hidden layers for the bottom MLP applied to dense
      features.
    top_mlp: The sizes of hidden layers for the top MLP.
    interaction: Interaction can be one of the following: 'dot', 'cross'.
  """
  num_dense_features: int = 13
  vocab_sizes: List[int] = dataclasses.field(default_factory=list)
  embedding_dim: Union[int, List[int]] = 8
  size_threshold: int = 50_000
  bottom_mlp: List[int] = dataclasses.field(default_factory=list)
  top_mlp: List[int] = dataclasses.field(default_factory=list)
  interaction: str = 'dot'
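

# Illustrative only (an assumption about how `size_threshold` is meant to be
# used, not code from the original module): split tables between Keras and
# TPU embedding layers based on their vocabulary size.
def _example_split_tables(model: ModelConfig):
  """Returns (keras_table_ids, tpu_table_ids) under the threshold rule above."""
  keras_tables, tpu_tables = [], []
  for table_id, vocab_size in enumerate(model.vocab_sizes):
    if model.size_threshold == -1 or (
        model.size_threshold > 0 and vocab_size <= model.size_threshold):
      keras_tables.append(table_id)  # small table: Keras embedding layer
    else:
      tpu_tables.append(table_id)  # large table: TPU embedding layer
  return keras_tables, tpu_tables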


@dataclasses.dataclass
class Loss(hyperparams.Config):
  """Configuration for Loss.

  Attributes:
    label_smoothing: Label smoothing factor applied to the binary
      cross-entropy loss. A value of 0.0 means no smoothing.
  """
  label_smoothing: float = 0.0


@dataclasses.dataclass
class Task(hyperparams.Config):
  """The model config."""
  init_checkpoint: str = ''
  model: ModelConfig = dataclasses.field(default_factory=ModelConfig)
  train_data: DataConfig = dataclasses.field(
      default_factory=lambda: DataConfig(is_training=True)
  )
  validation_data: DataConfig = dataclasses.field(
      default_factory=lambda: DataConfig(is_training=False)
  )
  loss: Loss = dataclasses.field(default_factory=Loss)
  use_synthetic_data: bool = False


@dataclasses.dataclass
class TimeHistoryConfig(hyperparams.Config):
  """Configuration for the TimeHistory callback.

  Attributes:
    log_steps: Interval of steps between logging of batch-level stats.
  """
  log_steps: Optional[int] = None


@dataclasses.dataclass
class TrainerConfig(cfg.TrainerConfig):
  """Configuration for training.

  Attributes:
    train_steps: The number of steps used to train.
    validation_steps: The number of steps used to eval.
    validation_interval: The number of training steps to run between
      evaluations.
    callbacks: An instance of CallbacksConfig.
    use_orbit: Whether to use the Orbit library with a custom training loop or
      the compile/fit API.
    enable_metrics_in_training: Whether to enable metrics during training.
    time_history: Config of the TimeHistory callback.
    optimizer_config: An `OptimizationConfig` instance for the embedding and
      dense optimizers.
    pipeline_sparse_and_dense_execution: Whether to pipeline embedding and
      dense execution. This is a performance optimization.
  """
  train_steps: int = 0
  # Set validation_steps to -1 to evaluate the entire dataset.
  validation_steps: int = -1
  validation_interval: int = 70000
  callbacks: CallbacksConfig = dataclasses.field(
      default_factory=CallbacksConfig
  )
  use_orbit: bool = False
  enable_metrics_in_training: bool = True
  time_history: TimeHistoryConfig = dataclasses.field(
      default_factory=lambda: TimeHistoryConfig(log_steps=5000)
  )
  optimizer_config: OptimizationConfig = dataclasses.field(
      default_factory=OptimizationConfig
  )
  pipeline_sparse_and_dense_execution: bool = False


NUM_TRAIN_EXAMPLES = 4195197692
NUM_EVAL_EXAMPLES = 89137318

train_batch_size = 16384
eval_batch_size = 16384
steps_per_epoch = NUM_TRAIN_EXAMPLES // train_batch_size
vocab_sizes = [
    39884406, 39043, 17289, 7420, 20263, 3, 7120, 1543, 63, 38532951,
    2953546, 403346, 10, 2208, 11938, 155, 4, 976, 14, 39979771, 25641295,
    39664984, 585935, 12972, 108, 36
]
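
# Derived values, shown here only for orientation: with the Criteo-sized
# constants above, steps_per_epoch works out to 4195197692 // 16384 == 256054
# steps, and a full evaluation pass is 89137318 // 16384 == 5440 steps.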


@dataclasses.dataclass
class Config(hyperparams.Config):
  """Configuration to train the RankingModel.

  By default it configures the DLRM model on the Criteo dataset.

  Attributes:
    runtime: A `RuntimeConfig` instance.
    task: A `Task` instance.
    trainer: A `TrainerConfig` instance.
  """
  runtime: cfg.RuntimeConfig = dataclasses.field(
      default_factory=cfg.RuntimeConfig
  )
  task: Task = dataclasses.field(
      default_factory=lambda: Task(  # pylint: disable=g-long-lambda
          model=ModelConfig(
              embedding_dim=8,
              vocab_sizes=vocab_sizes,
              bottom_mlp=[64, 32, 8],
              top_mlp=[64, 32, 1],
          ),
          loss=Loss(label_smoothing=0.0),
          train_data=DataConfig(
              is_training=True, global_batch_size=train_batch_size
          ),
          validation_data=DataConfig(
              is_training=False, global_batch_size=eval_batch_size
          ),
      )
  )
  trainer: TrainerConfig = dataclasses.field(
      default_factory=lambda: TrainerConfig(  # pylint: disable=g-long-lambda
          train_steps=2 * steps_per_epoch,
          validation_interval=steps_per_epoch,
          validation_steps=NUM_EVAL_EXAMPLES // eval_batch_size,
          enable_metrics_in_training=True,
          optimizer_config=OptimizationConfig(),
      )
  )
  restrictions: dataclasses.InitVar[Optional[List[str]]] = None
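

# Illustrative only (not in the original file): a minimal sketch of nested
# dict overrides, assuming `Config` inherits `override()` from the
# `hyperparams` / ParamsDict machinery. The values below are placeholders.
def _example_override() -> Config:
  config = Config()
  config.override({
      'task': {'model': {'embedding_dim': 16}},
      'trainer': {'train_steps': 10000},
  }, is_strict=True)
  return config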


def default_config() -> Config:
  return Config(
      runtime=cfg.RuntimeConfig(),
      task=Task(
          model=ModelConfig(
              embedding_dim=8,
              vocab_sizes=vocab_sizes,
              bottom_mlp=[64, 32, 4],
              top_mlp=[64, 32, 1]),
          loss=Loss(label_smoothing=0.0),
          train_data=DataConfig(
              global_batch_size=train_batch_size,
              is_training=True,
              sharding=True),
          validation_data=DataConfig(
              global_batch_size=eval_batch_size,
              is_training=False,
              sharding=False)),
      trainer=TrainerConfig(
          train_steps=2 * steps_per_epoch,
          validation_interval=steps_per_epoch,
          validation_steps=NUM_EVAL_EXAMPLES // eval_batch_size,
          enable_metrics_in_training=True,
          optimizer_config=OptimizationConfig()),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None',
      ])


def dlrm_criteo_tb_config() -> Config:
  return Config(
      runtime=cfg.RuntimeConfig(),
      task=Task(
          model=ModelConfig(
              num_dense_features=13,
              vocab_sizes=vocab_sizes,
              bottom_mlp=[512, 256, 64],
              embedding_dim=64,
              top_mlp=[1024, 1024, 512, 256, 1],
              interaction='dot'),
          loss=Loss(label_smoothing=0.0),
          train_data=DataConfig(
              global_batch_size=train_batch_size,
              is_training=True,
              sharding=True),
          validation_data=DataConfig(
              global_batch_size=eval_batch_size,
              is_training=False,
              sharding=False)),
      trainer=TrainerConfig(
          train_steps=steps_per_epoch,
          validation_interval=steps_per_epoch // 2,
          validation_steps=NUM_EVAL_EXAMPLES // eval_batch_size,
          enable_metrics_in_training=True,
          optimizer_config=OptimizationConfig()),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None',
      ])


def dcn_criteo_tb_config() -> Config:
  return Config(
      runtime=cfg.RuntimeConfig(),
      task=Task(
          model=ModelConfig(
              num_dense_features=13,
              vocab_sizes=vocab_sizes,
              bottom_mlp=[512, 256, 64],
              embedding_dim=64,
              top_mlp=[1024, 1024, 512, 256, 1],
              interaction='cross'),
          loss=Loss(label_smoothing=0.0),
          train_data=DataConfig(
              global_batch_size=train_batch_size,
              is_training=True,
              sharding=True),
          validation_data=DataConfig(
              global_batch_size=eval_batch_size,
              is_training=False,
              sharding=False)),
      trainer=TrainerConfig(
          train_steps=steps_per_epoch,
          validation_interval=steps_per_epoch // 2,
          validation_steps=NUM_EVAL_EXAMPLES // eval_batch_size,
          enable_metrics_in_training=True,
          optimizer_config=OptimizationConfig()),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None',
      ])
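

# Illustrative usage only (not part of the original module): build one of the
# predefined configs and inspect it before handing it to a trainer. The DCN
# variant differs from the DLRM variant only in the interaction layer
# ('cross' vs. 'dot').
if __name__ == '__main__':
  example_config = dcn_criteo_tb_config()
  print(example_config.task.model.interaction)  # -> 'cross'
  print(example_config.trainer.validation_steps)  # -> 5440 with the defaults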