# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for ranking model and associated functionality."""
import json
import os
from absl import flags
from absl.testing import parameterized
import tensorflow as tf, tf_keras
from official.recommendation.ranking import common
from official.recommendation.ranking import train
FLAGS = flags.FLAGS
def _get_params_override(vocab_sizes,
interaction='dot',
use_orbit=True,
strategy='mirrored'):
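  """Returns a JSON `params_override` string for the ranking trainer.

  The override configures a small DLRM/DCN-style model on synthetic data so
  the tests can run end to end quickly under the given distribution strategy.
  """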
  # Point `data_dir` at real data when `use_synthetic_data` is False.
data_dir = ''
return json.dumps({
'runtime': {
'distribution_strategy': strategy,
},
'task': {
'model': {
'vocab_sizes': vocab_sizes,
'embedding_dim': [8] * len(vocab_sizes),
'bottom_mlp': [64, 32, 8],
'interaction': interaction,
},
'train_data': {
'input_path': os.path.join(data_dir, 'train/*'),
'global_batch_size': 16,
},
'validation_data': {
'input_path': os.path.join(data_dir, 'eval/*'),
'global_batch_size': 16,
},
'use_synthetic_data': True,
},
'trainer': {
'use_orbit': use_orbit,
'validation_interval': 20,
'validation_steps': 20,
'train_steps': 40,
},
  })


class TrainTest(parameterized.TestCase, tf.test.TestCase):
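  """Smoke tests that run the trainer end to end on synthetic data."""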
def setUp(self):
super().setUp()
self._temp_dir = self.get_temp_dir()
self._model_dir = os.path.join(self._temp_dir, 'model_dir')
tf.io.gfile.makedirs(self._model_dir)
FLAGS.model_dir = self._model_dir
    FLAGS.tpu = ''

  def tearDown(self):
tf.io.gfile.rmtree(self._model_dir)
    super().tearDown()

  @parameterized.named_parameters(
('DlrmOneDeviceCTL', 'one_device', 'dot', True),
('DlrmOneDevice', 'one_device', 'dot', False),
('DcnOneDeviceCTL', 'one_device', 'cross', True),
('DcnOneDevice', 'one_device', 'cross', False),
('DlrmTPUCTL', 'tpu', 'dot', True),
('DlrmTPU', 'tpu', 'dot', False),
('DcnTPUCTL', 'tpu', 'cross', True),
('DcnTPU', 'tpu', 'cross', False),
('DlrmMirroredCTL', 'Mirrored', 'dot', True),
('DlrmMirrored', 'Mirrored', 'dot', False),
('DcnMirroredCTL', 'Mirrored', 'cross', True),
('DcnMirrored', 'Mirrored', 'cross', False),
)
def testTrainEval(self, strategy, interaction, use_orbit=True):
    # Set up a simple trainer with synthetic data.
    # The default mode is expected to be `train_and_eval`.
self.assertEqual(FLAGS.mode, 'train_and_eval')
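    # Four small sparse-feature vocabularies keep the synthetic model tiny.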
vocab_sizes = [40, 12, 11, 13]
FLAGS.params_override = _get_params_override(vocab_sizes=vocab_sizes,
interaction=interaction,
use_orbit=use_orbit,
strategy=strategy)
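    # train.main ignores its positional argument; configuration comes from
    # FLAGS (model_dir, mode, and params_override set above).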
train.main('unused_args')
    self.assertNotEmpty(
        tf.io.gfile.glob(os.path.join(self._model_dir, 'params.yaml')))

  @parameterized.named_parameters(
('DlrmTPUCTL', 'tpu', 'dot', True),
('DlrmTPU', 'tpu', 'dot', False),
('DcnTPUCTL', 'tpu', 'cross', True),
('DcnTPU', 'tpu', 'cross', False),
('DlrmMirroredCTL', 'Mirrored', 'dot', True),
('DlrmMirrored', 'Mirrored', 'dot', False),
('DcnMirroredCTL', 'Mirrored', 'cross', True),
('DcnMirrored', 'Mirrored', 'cross', False),
)
def testTrainThenEval(self, strategy, interaction, use_orbit=True):
    # Set up a simple trainer with synthetic data; training and evaluation
    # run as separate passes below.
vocab_sizes = [40, 12, 11, 13]
FLAGS.params_override = _get_params_override(vocab_sizes=vocab_sizes,
interaction=interaction,
use_orbit=use_orbit,
strategy=strategy)
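    # Remember the default mode so it can be restored after the eval pass.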
default_mode = FLAGS.mode
# Training.
FLAGS.mode = 'train'
train.main('unused_args')
self.assertNotEmpty(
tf.io.gfile.glob(os.path.join(self._model_dir, 'params.yaml')))
# Evaluation.
FLAGS.mode = 'eval'
train.main('unused_args')
    FLAGS.mode = default_mode


if __name__ == '__main__':
common.define_flags()
tf.test.main()
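# Example invocation, assuming this module is saved as
# official/recommendation/ranking/train_test.py in the TF Model Garden:
#   python3 -m official.recommendation.ranking.train_test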