deanna-emery's picture
updates
93528c6
raw
history blame
4.53 kB
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for TFM actions."""
import os
from absl.testing import parameterized
import numpy as np
import orbit
import tensorflow as tf, tf_keras
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.core import actions
from official.modeling import optimization
class TestModel(tf_keras.Model):
def __init__(self):
super().__init__()
self.value = tf.Variable(0.0)
self.dense = tf_keras.layers.Dense(2)
_ = self.dense(tf.zeros((2, 2), tf.float32))
def call(self, x, training=None):
return self.value + x
class ActionsTest(tf.test.TestCase, parameterized.TestCase):
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.cloud_tpu_strategy,
strategy_combinations.one_device_strategy,
],))
def test_ema_checkpointing(self, distribution):
with distribution.scope():
directory = self.create_tempdir()
model = TestModel()
optimizer = tf_keras.optimizers.SGD()
optimizer = optimization.ExponentialMovingAverage(
optimizer, trainable_weights_only=False)
# Creats average weights for the model variables. Average weights are
# initialized to zero.
optimizer.shadow_copy(model)
checkpoint = tf.train.Checkpoint(model=model)
# Changes model.value to 3, average value is still 0.
model.value.assign(3)
# Checks model.value is 3
self.assertEqual(model(0.), 3)
ema_action = actions.EMACheckpointing(directory, optimizer, checkpoint)
ema_action({})
self.assertNotEmpty(
tf.io.gfile.glob(os.path.join(directory, 'ema_checkpoints')))
checkpoint.read(
tf.train.latest_checkpoint(
os.path.join(directory, 'ema_checkpoints')))
# Checks model.value is 0 after swapping.
self.assertEqual(model(0.), 0)
# Raises an error for a normal optimizer.
with self.assertRaisesRegex(ValueError,
'Optimizer has to be instance of.*'):
_ = actions.EMACheckpointing(directory, tf_keras.optimizers.SGD(),
checkpoint)
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.default_strategy,
strategy_combinations.cloud_tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
],))
def test_recovery_condition(self, distribution):
with distribution.scope():
global_step = orbit.utils.create_global_step()
recover_condition = actions.RecoveryCondition(
global_step, loss_upper_bound=0.5, recovery_max_trials=2)
outputs = {'training_loss': 0.6}
self.assertTrue(recover_condition(outputs))
self.assertTrue(recover_condition(outputs))
with self.assertRaises(RuntimeError):
recover_condition(outputs)
global_step = orbit.utils.create_global_step()
recover_condition = actions.RecoveryCondition(
global_step, loss_upper_bound=0.5, recovery_max_trials=2)
outputs = {'training_loss': tf.constant([np.nan], tf.float32)}
self.assertTrue(recover_condition(outputs))
self.assertTrue(recover_condition(outputs))
with self.assertRaises(RuntimeError):
recover_condition(outputs)
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.one_device_strategy_gpu,
strategy_combinations.one_device_strategy,
],))
def test_pruning(self, distribution):
with distribution.scope():
directory = self.get_temp_dir()
model = TestModel()
optimizer = tf_keras.optimizers.SGD()
pruning = actions.PruningAction(directory, model, optimizer)
pruning({})
if __name__ == '__main__':
tf.test.main()