# Hosting-page metadata: repository ASL-MoViNet-T5-translator,
# path official/nlp/continuous_finetune_lib_test.py,
# commit 93528c6 ("updates", deanna-emery), 3.19 kB.
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from absl import flags
from absl.testing import flagsaver
from absl.testing import parameterized
import tensorflow as tf, tf_keras
# pylint: disable=unused-import
from official.common import registry_imports
# pylint: enable=unused-import
from official.common import flags as tfm_flags
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.nlp import continuous_finetune_lib
# Module-level alias for absl's global flag registry; the test below reads
# `FLAGS.mode` and `FLAGS.model_dir` through it.
FLAGS = flags.FLAGS
# Runs at import time. Presumably registers the `experiment`, `mode`,
# `model_dir` and `params_override` flags that the test overrides via
# `flagsaver` -- TODO confirm against official/common/flags.py.
tfm_flags.define_flags()
class ContinuousFinetuneTest(tf.test.TestCase, parameterized.TestCase):
  """Exercises `continuous_finetune_lib.run_continuous_finetune` end to end."""

  def setUp(self):
    """Records the path of a `model_dir` folder under the test temp dir."""
    super().setUp()
    temp_root = self.get_temp_dir()
    self._model_dir = os.path.join(temp_root, 'model_dir')

  def testContinuousFinetune(self):
    """Trains the 'mock' experiment once, then continuously finetunes from it.

    The first phase runs `train_lib.run_experiment` in 'train' mode and writes
    into `src_model_dir`, which is also passed as the task's
    `init_checkpoint`. The second phase calls `run_continuous_finetune` with
    the overridden flags and checks that 'best_acc' is among the returned
    metrics and that no 'checkpoint' file exists under `FLAGS.model_dir`.
    """
    pretrain_steps = 1
    src_model_dir = self.get_temp_dir()

    # Assemble the nested `params_override` piece by piece; key order matches
    # the order in which the overrides were originally listed.
    optimizer_overrides = {
        'optimizer': {
            'type': 'sgd',
        },
        'learning_rate': {
            'type': 'constant',
        },
    }
    trainer_overrides = {
        'continuous_eval_timeout': 1,
        'steps_per_loop': 1,
        'train_steps': 1,
        'validation_steps': 1,
        'best_checkpoint_export_subdir': 'best_ckpt',
        'best_checkpoint_eval_metric': 'acc',
        'optimizer_config': optimizer_overrides,
    }
    task_overrides = {
        'init_checkpoint': src_model_dir,
    }
    flag_overrides = {
        'experiment': 'mock',
        'mode': 'continuous_train_and_eval',
        'model_dir': self._model_dir,
        'params_override': {
            'task': task_overrides,
            'trainer': trainer_overrides,
        },
    }

    # Everything below reads the overridden flag values, so it must stay
    # inside the flagsaver context.
    with flagsaver.flagsaver(**flag_overrides):
      # Phase 1: train and save some checkpoints into `src_model_dir`.
      pretrain_params = train_utils.parse_configuration(FLAGS)
      strategy = tf.distribute.get_strategy()
      with strategy.scope():
        pretrain_task = task_factory.get_task(
            pretrain_params.task, logging_dir=src_model_dir)
      train_lib.run_experiment(
          distribution_strategy=strategy,
          task=pretrain_task,
          mode='train',
          params=pretrain_params,
          model_dir=src_model_dir)

      # Phase 2: re-parse the configuration and run continuous finetuning.
      finetune_params = train_utils.parse_configuration(FLAGS)
      metrics = continuous_finetune_lib.run_continuous_finetune(
          FLAGS.mode,
          finetune_params,
          FLAGS.model_dir,
          run_post_eval=True,
          pretrain_steps=pretrain_steps)
      self.assertIn('best_acc', metrics)

      # No checkpoint index file is expected directly under the finetune
      # model dir.
      checkpoint_index = os.path.join(FLAGS.model_dir, 'checkpoint')
      self.assertFalse(tf.io.gfile.exists(checkpoint_index))
if __name__ == '__main__':
  # Run the tests through TensorFlow's test entry point when executed as a
  # script.
  tf.test.main()