deanna-emery's picture
updates
5672777
raw
history blame
31.8 kB
# Copyright 2023 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for orbit.controller."""
import os
from absl import logging
from absl.testing import parameterized
import numpy as np
from orbit import controller
from orbit import runner
from orbit import standard_runner
import orbit.utils
import tensorflow as tf, tf_keras
def create_model():
  """Builds the tiny functional Keras model shared by the test runners.

  Returns:
    A `tf_keras.Model` that maps an input of shape `(3,)` (layer name
    "input") through a single `Dense(4)` layer (layer name "dense").
  """
  model_input = tf_keras.layers.Input(shape=(3,), name="input")
  dense_layer = tf_keras.layers.Dense(4, name="dense")
  return tf_keras.Model(model_input, dense_layer(model_input))
def summaries_with_matching_keyword(keyword, summary_dir):
  """Collects summary protos whose value tags contain `keyword`.

  Only the last path returned by globbing `events*` inside `summary_dir` is
  read.

  Args:
    keyword: Substring to look for in each summary value's tag.
    summary_dir: Directory containing the event file(s).

  Returns:
    A list of summary protos. An event's summary is added once for every one
    of its values whose tag contains `keyword`, so the same summary can appear
    more than once.
  """
  event_paths = tf.io.gfile.glob(os.path.join(summary_dir, "events*"))
  found = []
  for event in tf.compat.v1.train.summary_iterator(event_paths[-1]):
    summary = event.summary
    if summary is None:
      continue
    # One entry per matching value, mirroring a nested append loop.
    found.extend(summary for value in summary.value if keyword in value.tag)
  return found
def dataset_fn(ctx):
  """Creates the synthetic dataset used by every runner in this file.

  Args:
    ctx: Context object handed in by the caller; it is discarded.

  Returns:
    A `tf.data.Dataset` of `(inputs, targets)` pairs built from 10 all-zero
    rows of width 3 and 10 all-one rows of width 4, repeated 100 times and
    grouped into batches of 10 (partial batches dropped).
  """
  del ctx  # Unused.
  features = np.zeros((10, 3), dtype=np.float32)
  labels = np.ones((10, 4), dtype=np.float32)
  return (
      tf.data.Dataset.from_tensor_slices((features, labels))
      .repeat(100)
      .batch(10, drop_remainder=True))
class TestRunner(standard_runner.StandardTrainer,
                 standard_runner.StandardEvaluator):
  """Combined trainer and evaluator built around the small test model."""

  def __init__(self, return_numpy=False):
    """Sets up the model, optimizer, loss metrics and distributed datasets.

    Args:
      return_numpy: When true, `train_loop_end` and `eval_end` report the loss
        via `.numpy()` instead of returning the metric result directly.
    """
    self.strategy = tf.distribute.get_strategy()
    self.model = create_model()
    self.optimizer = tf_keras.optimizers.RMSprop(learning_rate=0.1)
    # The optimizer's iteration counter is reused as the global step.
    self.global_step = self.optimizer.iterations
    self.train_loss = tf_keras.metrics.Mean("train_loss", dtype=tf.float32)
    self.eval_loss = tf_keras.metrics.Mean("eval_loss", dtype=tf.float32)
    self.return_numpy = return_numpy

    distribute = self.strategy.distribute_datasets_from_function
    training_data = distribute(dataset_fn)
    evaluation_data = distribute(dataset_fn)
    standard_runner.StandardTrainer.__init__(self, training_data)
    standard_runner.StandardEvaluator.__init__(self, evaluation_data)

  def train_step(self, iterator):
    """Applies one gradient update using the next batch from `iterator`."""

    def _step_fn(batch):
      """Per-replica training step."""
      features, labels = batch
      with tf.GradientTape() as tape:
        predictions = self.model(features)
        loss = tf.reduce_mean(tf_keras.losses.MSE(labels, predictions))
      variables = self.model.variables
      gradients = tape.gradient(loss, variables)
      self.optimizer.apply_gradients(zip(gradients, variables))
      self.train_loss.update_state(loss)

    self.strategy.run(_step_fn, args=(next(iterator),))

  def train_loop_end(self):
    """Returns `{"loss": ...}` holding the current `train_loss` result."""
    loss = self.train_loss.result()
    if self.return_numpy:
      loss = loss.numpy()
    return {"loss": loss}

  def build_eval_dataset(self):
    """Returns a freshly distributed dataset built from `dataset_fn`."""
    return self.strategy.distribute_datasets_from_function(dataset_fn)

  def eval_begin(self):
    """Clears the accumulated state of the `eval_loss` metric."""
    self.eval_loss.reset_states()

  def eval_step(self, iterator):
    """Updates `eval_loss` with the loss on the next batch from `iterator`."""

    def _step_fn(batch):
      """Per-replica evaluation step."""
      features, labels = batch
      predictions = self.model(features)
      loss = tf.reduce_mean(tf_keras.losses.MSE(labels, predictions))
      self.eval_loss.update_state(loss)

    self.strategy.run(_step_fn, args=(next(iterator),))

  def eval_end(self):
    """Returns `{"eval_loss": ...}` holding the current `eval_loss` result."""
    loss = self.eval_loss.result()
    if self.return_numpy:
      loss = loss.numpy()
    return {"eval_loss": loss}
class TestEvaluator(standard_runner.StandardEvaluator):
  """Evaluator for the test model that collects per-step losses in a list."""

  def __init__(self):
    """Creates the model and the distributed evaluation dataset."""
    self.strategy = tf.distribute.get_strategy()
    self.model = create_model()
    evaluation_data = self.strategy.distribute_datasets_from_function(
        dataset_fn)
    standard_runner.StandardEvaluator.__init__(self, evaluation_data)

  def eval_reduce(self, state, output):
    """Appends `output` to `state` in place and returns `state`."""
    state.append(output)
    return state

  def eval_begin(self):
    """Returns an empty list to start from."""
    return []

  def eval_step(self, iterator):
    """Returns the MSE loss on the next batch, averaged across replicas."""

    def _step_fn(batch):
      """Per-replica evaluation step."""
      features, labels = batch
      predictions = self.model(features)
      return tf.reduce_mean(tf_keras.losses.MSE(labels, predictions))

    replica_losses = self.strategy.run(_step_fn, args=(next(iterator),))
    return self.strategy.reduce(
        tf.distribute.ReduceOp.MEAN, replica_losses, axis=None)

  def eval_end(self, outputs):
    """Returns `{"eval_loss": ...}` with the mean over `outputs`."""
    averaged_loss = tf.reduce_mean(outputs)
    return {"eval_loss": averaged_loss}
class TestEvaluatorNoOutput(runner.AbstractEvaluator):
  """Evaluator stub whose `evaluate` does no work and returns `None`."""

  def evaluate(self, num_steps):
    """Ignores `num_steps` and produces no evaluation output."""
    return None
class TestEvaluatorWithNestedSummary(standard_runner.StandardEvaluator):
  """Evaluator over two named datasets that reports nested metric dicts."""

  def __init__(self):
    """Creates the model, two distributed datasets and a metric pair each."""
    self.strategy = tf.distribute.get_strategy()
    self.model = create_model()
    first_dataset = self.strategy.distribute_datasets_from_function(
        dataset_fn)
    second_dataset = self.strategy.distribute_datasets_from_function(
        dataset_fn)
    # Metrics for "dataset".
    self.loss = tf_keras.metrics.Mean("loss", dtype=tf.float32)
    self.accuracy = tf_keras.metrics.CategoricalAccuracy(
        "accuracy", dtype=tf.float32)
    # Metrics for "dataset2".
    self.loss2 = tf_keras.metrics.Mean("loss", dtype=tf.float32)
    self.accuracy2 = tf_keras.metrics.CategoricalAccuracy(
        "accuracy", dtype=tf.float32)
    datasets_by_name = {"dataset": first_dataset, "dataset2": second_dataset}
    standard_runner.StandardEvaluator.__init__(
        self, eval_dataset=datasets_by_name)

  def eval_step(self, iterator):
    """Updates each dataset's metrics from that dataset's next batch.

    Args:
      iterator: Mapping with keys "dataset" and "dataset2", each holding the
        iterator for the corresponding dataset.
    """

    def _make_step_fn(loss_metric, accuracy_metric):
      """Returns a per-replica step that updates the given metric pair."""

      def _step_fn(batch):
        features, labels = batch
        predictions = self.model(features)
        loss_metric.update_state(tf_keras.losses.MSE(labels, predictions))
        accuracy_metric.update_state(labels, predictions)

      return _step_fn

    self.strategy.run(
        _make_step_fn(self.loss, self.accuracy),
        args=(next(iterator["dataset"]),))
    self.strategy.run(
        _make_step_fn(self.loss2, self.accuracy2),
        args=(next(iterator["dataset2"]),))

  def eval_end(self):
    """Returns loss and accuracy results nested under each dataset name."""
    first_results = {
        "loss": self.loss.result(),
        "accuracy": self.accuracy.result(),
    }
    second_results = {
        "loss": self.loss2.result(),
        "accuracy": self.accuracy2.result(),
    }
    return {"dataset": first_results, "dataset2": second_results}
class TestTrainerWithSummaries(standard_runner.StandardTrainer):
  """Trainer for the test model that calls `tf.summary.scalar` per step."""

  def __init__(self):
    """Creates model, optimizer, loss metric and the training dataset.

    The base trainer is initialised with
    `StandardTrainerOptions(use_tpu_summary_optimization=True)`.
    """
    self.strategy = tf.distribute.get_strategy()
    self.model = create_model()
    self.optimizer = tf_keras.optimizers.RMSprop(learning_rate=0.1)
    # The optimizer's iteration counter is reused as the global step.
    self.global_step = self.optimizer.iterations
    self.train_loss = tf_keras.metrics.Mean("train_loss", dtype=tf.float32)
    training_data = self.strategy.distribute_datasets_from_function(
        dataset_fn)
    trainer_options = standard_runner.StandardTrainerOptions(
        use_tpu_summary_optimization=True)
    standard_runner.StandardTrainer.__init__(
        self, training_data, options=trainer_options)

  def build_train_dataset(self):
    """Returns a freshly distributed dataset built from `dataset_fn`."""
    return self.strategy.distribute_datasets_from_function(dataset_fn)

  def train_step(self, iterator):
    """Runs one update on the next batch and records a "loss" scalar."""

    def _step_fn(batch):
      """Per-replica training step."""
      features, labels = batch
      with tf.GradientTape() as tape:
        predictions = self.model(features)
        loss = tf.reduce_mean(tf_keras.losses.MSE(labels, predictions))
        tf.summary.scalar("loss", loss)
      variables = self.model.variables
      gradients = tape.gradient(loss, variables)
      self.optimizer.apply_gradients(zip(gradients, variables))
      self.train_loss.update_state(loss)

    self.strategy.run(_step_fn, args=(next(iterator),))
class ControllerTest(tf.test.TestCase, parameterized.TestCase):
  """Tests `controller.Controller` with the test runners defined above.

  Each test uses a per-test temporary directory (`self.model_dir`) for
  checkpoints and summaries.
  """

  def setUp(self):
    """Allocates the temporary directory used as `model_dir`."""
    super().setUp()
    self.model_dir = self.get_temp_dir()

  def test_no_checkpoint(self):
    """Trains and evaluates twice without a checkpoint manager."""
    test_runner = TestRunner()
    # No checkpoint manager and no strategy.
    test_controller = controller.Controller(
        trainer=test_runner,
        evaluator=test_runner,
        global_step=test_runner.global_step,
        steps_per_loop=2,
        summary_dir=os.path.join(self.model_dir, "summaries/train"),
        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=2, eval_interval=6)
    self.assertEqual(test_runner.global_step, 10)
    # Loss values should be written into summaries.
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
    self.assertNotEmpty(
        summaries_with_matching_keyword(
            "loss", os.path.join(self.model_dir, "summaries/train")))
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval")))
    self.assertNotEmpty(
        summaries_with_matching_keyword(
            "eval_loss", os.path.join(self.model_dir, "summaries/eval")))
    # No checkpoint, so global step starts from 0.
    test_runner.global_step.assign(0)
    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=2, eval_interval=6)
    self.assertEqual(test_runner.global_step, 10)
    self.assertTrue(controller._orbit_api_gauge.get_cell().value())

  def test_no_checkpoint_and_summaries(self):
    """Trains and evaluates with neither checkpoints nor summary dirs."""
    test_runner = TestRunner()
    # No checkpoint + summary directories.
    test_controller = controller.Controller(
        trainer=test_runner,
        evaluator=test_runner,
        global_step=test_runner.global_step,
        steps_per_loop=2)
    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=2, eval_interval=6)
    self.assertEqual(test_runner.global_step, 10)
    self.assertTrue(controller._orbit_api_gauge.get_cell().value())

  @parameterized.named_parameters(
      ("_sync_checkpoint_saving", False),
      ("_async_checkpoint_saving", True)
  )
  def test_has_checkpoint_no_summaries(self, enable_async_checkpoint_saving):
    """Checks that no event files appear when no summary dir is given."""
    test_runner = TestRunner()
    # Has checkpoint, but no summary directories.
    checkpoint = tf.train.Checkpoint(model=test_runner.model)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step)
    test_controller = controller.Controller(
        trainer=test_runner,
        evaluator=test_runner,
        global_step=test_runner.global_step,
        checkpoint_manager=checkpoint_manager,
        enable_async_checkpointing=enable_async_checkpoint_saving,
        steps_per_loop=2)
    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=2, eval_interval=6)
    self.assertEqual(test_runner.global_step, 10)
    self.assertTrue(controller._orbit_api_gauge.get_cell().value())
    # No summaries are saved.
    self.assertEmpty(tf.io.gfile.glob(
        os.path.join(checkpoint_manager.directory, "events.*")))

  @parameterized.named_parameters(
      ("_sync_checkpoint_saving", False),
      ("_async_checkpoint_saving", True)
  )
  def test_has_checkpoint_eval_summary_only(
      self, enable_async_checkpoint_saving
  ):
    """Checks that only eval event files appear with only an eval dir."""
    test_runner = TestRunner()
    # Has checkpoint and an eval summary directory, but no train summary
    # directory.
    checkpoint = tf.train.Checkpoint(model=test_runner.model)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step)
    test_controller = controller.Controller(
        trainer=test_runner,
        evaluator=test_runner,
        global_step=test_runner.global_step,
        checkpoint_manager=checkpoint_manager,
        enable_async_checkpointing=enable_async_checkpoint_saving,
        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
        steps_per_loop=2)
    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=2, eval_interval=6)
    self.assertEqual(test_runner.global_step, 10)
    # Training summaries are not saved.
    self.assertEmpty(tf.io.gfile.glob(
        os.path.join(checkpoint_manager.directory, "events.*")))
    # Evaluation summaries are saved.
    self.assertNotEmpty(tf.io.gfile.glob(
        os.path.join(self.model_dir, "summaries/eval/events.*")))

  @parameterized.named_parameters(
      ("_sync_checkpoint_saving", False),
      ("_async_checkpoint_saving", True)
  )
  def test_restore_from_most_recent_checkpoint(
      self, enable_async_checkpoint_saving
  ):
    """Checks `restore_checkpoint` returns the manager's last checkpoint."""
    test_runner = TestRunner()
    checkpoint = tf.train.Checkpoint(model=test_runner.model)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step,
        checkpoint_interval=5)
    test_controller = controller.Controller(
        trainer=test_runner,
        global_step=test_runner.global_step,
        checkpoint_manager=checkpoint_manager,
        enable_async_checkpointing=enable_async_checkpoint_saving,
        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
        steps_per_loop=5)
    test_controller.train(20)
    # 20 steps with a checkpoint interval of 5 are expected to leave 4
    # checkpoints.
    self.assertLen(checkpoint_manager.checkpoints, 4)
    restored_path = test_controller.restore_checkpoint()
    self.assertEqual(restored_path, checkpoint_manager.checkpoints[-1])

  @parameterized.named_parameters(
      ("return_numpy_sync_checkpoint_saving", True, False),
      ("return_numpy_async_checkpoint_saving", True, True),
      ("return_tensor_sync_checkpoint_saving", False, False),
      ("return_tensor_async_checkpoint_saving", False, True),
  )
  def test_train_and_evaluate(
      self, return_numpy, enable_async_checkpoint_saving
  ):
    """Checks checkpoints and both summary dirs after train_and_evaluate."""
    test_runner = TestRunner(return_numpy=return_numpy)
    checkpoint = tf.train.Checkpoint(
        model=test_runner.model, optimizer=test_runner.optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step,
        checkpoint_interval=10)
    test_controller = controller.Controller(
        trainer=test_runner,
        evaluator=test_runner,
        global_step=test_runner.global_step,
        steps_per_loop=2,
        summary_dir=os.path.join(self.model_dir, "summaries/train"),
        checkpoint_manager=checkpoint_manager,
        enable_async_checkpointing=enable_async_checkpoint_saving,
        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=2, eval_interval=6)
    # Checkpoints are saved.
    self.assertNotEmpty(tf.io.gfile.glob(os.path.join(self.model_dir, "ckpt*")))
    # Loss values should be written into summaries.
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
    self.assertNotEmpty(
        summaries_with_matching_keyword(
            "loss", os.path.join(self.model_dir, "summaries/train")))
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval")))
    self.assertNotEmpty(
        summaries_with_matching_keyword(
            "eval_loss", os.path.join(self.model_dir, "summaries/eval")))

  @parameterized.named_parameters(
      ("_sync_checkpoint_saving", False),
      ("_async_checkpoint_saving", True)
  )
  def test_train_only(self, enable_async_checkpoint_saving):
    """Checks that `train` alone writes checkpoints and train summaries."""
    test_runner = TestRunner()
    checkpoint = tf.train.Checkpoint(
        model=test_runner.model, optimizer=test_runner.optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step,
        checkpoint_interval=10)
    test_controller = controller.Controller(
        trainer=test_runner,
        global_step=test_runner.global_step,
        steps_per_loop=2,
        summary_dir=os.path.join(self.model_dir, "summaries/train"),
        checkpoint_manager=checkpoint_manager,
        enable_async_checkpointing=enable_async_checkpoint_saving,
        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
    )
    test_controller.train(steps=10)
    # Checkpoints are saved.
    self.assertNotEmpty(tf.io.gfile.glob(os.path.join(self.model_dir, "ckpt*")))
    # Only train summaries are written.
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
    self.assertNotEmpty(
        summaries_with_matching_keyword(
            "loss", os.path.join(self.model_dir, "summaries/train")))
    self.assertFalse(
        tf.io.gfile.exists(os.path.join(self.model_dir, "summaries/eval")))

  def test_evaluate_only(self):
    """Checks `evaluate` and `evaluate_continuously` without a trainer."""
    test_runner = TestRunner()
    checkpoint = tf.train.Checkpoint(model=test_runner.model)
    # Save a checkpoint up front so one exists before evaluation starts.
    checkpoint.save(os.path.join(self.model_dir, "ckpt"))
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step)
    test_controller = controller.Controller(
        evaluator=test_runner,
        global_step=test_runner.global_step,
        checkpoint_manager=checkpoint_manager,
        summary_dir=os.path.join(self.model_dir, "summaries/train"),
        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
    eval_results = test_controller.evaluate(steps=2)
    # Only eval summaries are written
    self.assertFalse(
        tf.io.gfile.exists(os.path.join(self.model_dir, "summaries/train")))
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval")))
    self.assertNotEmpty(
        summaries_with_matching_keyword(
            "eval_loss", os.path.join(self.model_dir, "summaries/eval")))
    self.assertIn("eval_loss", eval_results)
    # Tests continuous eval with timeout and timeout_fn.
    done_file = os.path.join(self.model_dir, "summaries/eval/Done")

    def timeout_fn():
      # Writes a marker file so the test can verify this callback ran.
      with tf.io.gfile.GFile(done_file, "w") as f:
        f.write("DONE")
        return True

    test_controller = controller.Controller(
        evaluator=test_runner,
        global_step=test_runner.global_step,
        checkpoint_manager=checkpoint_manager,
        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
    test_controller.evaluate_continuously(
        timeout=1, timeout_fn=timeout_fn, steps=2)
    self.assertNotEmpty(tf.io.gfile.glob(done_file))

  def test_no_eval_steps(self):
    """Calls `evaluate()` without `steps`; only checks that it completes."""
    test_runner = TestRunner()
    checkpoint = tf.train.Checkpoint(model=test_runner.model)
    checkpoint.save(os.path.join(self.model_dir, "ckpt"))
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step)
    test_controller = controller.Controller(
        evaluator=test_runner,
        global_step=test_runner.global_step,
        checkpoint_manager=checkpoint_manager)
    test_controller.evaluate()

  @parameterized.named_parameters(
      ("_sync_checkpoint_saving", False),
      ("_async_checkpoint_saving", True)
  )
  def test_already_trained_model(self, enable_async_checkpoint_saving):
    """Calls `train` when `global_step` already equals the target steps.

    There are no assertions; the test only checks that the call completes.
    """
    test_runner = TestRunner()
    test_runner.global_step.assign(10)
    checkpoint = tf.train.Checkpoint(
        model=test_runner.model, optimizer=test_runner.optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step,
        checkpoint_interval=10)
    test_controller = controller.Controller(
        trainer=test_runner,
        global_step=test_runner.global_step,
        steps_per_loop=2,
        checkpoint_manager=checkpoint_manager,
        enable_async_checkpointing=enable_async_checkpoint_saving)
    # `global_step` is already `train_steps`.
    test_controller.train(steps=10)

  def test_summaries_inside_train_fn(self):
    """Checks that a scalar written inside `train_step` reaches the summary."""
    test_runner = TestTrainerWithSummaries()
    checkpoint = tf.train.Checkpoint(
        model=test_runner.model, optimizer=test_runner.optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step)
    test_controller = controller.Controller(
        trainer=test_runner,
        global_step=test_runner.global_step,
        steps_per_loop=2,
        summary_dir=os.path.join(self.model_dir, "summaries/train"),
        summary_interval=2,
        checkpoint_manager=checkpoint_manager
    )
    test_controller.train(steps=10)
    # No checkpoint files are expected here (presumably because the manager
    # was created without a `checkpoint_interval` -- confirm against
    # Controller).
    self.assertEmpty(tf.io.gfile.glob(os.path.join(self.model_dir, "ckpt*")))
    # Only train summaries are written.
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
    self.assertNotEmpty(
        summaries_with_matching_keyword(
            "loss", os.path.join(self.model_dir, "summaries/train")))
    self.assertFalse(
        tf.io.gfile.exists(os.path.join(self.model_dir, "summaries/eval")))

  def test_train_and_evaluate_with_same_summary_dir(self):
    """Checks train and eval summaries when both use one directory."""
    test_runner = TestRunner()
    checkpoint = tf.train.Checkpoint(
        model=test_runner.model, optimizer=test_runner.optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step)
    test_controller = controller.Controller(
        trainer=test_runner,
        evaluator=test_runner,
        global_step=test_runner.global_step,
        steps_per_loop=2,
        summary_dir=os.path.join(self.model_dir, "summaries"),
        checkpoint_manager=checkpoint_manager,
        eval_summary_dir=os.path.join(self.model_dir, "summaries"))
    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=2, eval_interval=6)
    # Loss values should be written into summaries.
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries")))
    self.assertNotEmpty(
        summaries_with_matching_keyword(
            "loss", os.path.join(self.model_dir, "summaries")))
    self.assertNotEmpty(
        summaries_with_matching_keyword(
            "eval_loss", os.path.join(self.model_dir, "summaries")))

  def test_early_stop_on_eval_loss(self):
    """Checks a Controller subclass can stop before reaching `train_steps`."""
    test_runner = TestRunner()

    class EarlyStopController(controller.Controller):
      """A subclass of Controller that supports early stopping."""

      def train_and_evaluate(self,
                             train_steps: int = None,
                             eval_steps: int = None,
                             eval_interval: int = None):
        # Alternate between training up to the next evaluation point and
        # evaluating, until `train_steps` is reached or the loss is low enough.
        while self.global_step.numpy() < train_steps:
          interval = min(train_steps - self.global_step.numpy(), eval_interval)
          num_steps = self.global_step.numpy() + interval
          self.train(steps=num_steps, checkpoint_at_completion=False)
          self._sync_on_async_checkpointing()
          self.evaluate(steps=eval_steps)
          # Early stop condition.
          if test_runner.eval_loss.result() < 0.1:
            logging.info(
                "Training early stopped as eval_loss %s is less than 0.1",
                test_runner.eval_loss.result())
            return

    checkpoint = tf.train.Checkpoint(
        model=test_runner.model, optimizer=test_runner.optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step,
        checkpoint_interval=10)
    test_controller = EarlyStopController(
        trainer=test_runner,
        evaluator=test_runner,
        global_step=test_runner.global_step,
        steps_per_loop=2,
        checkpoint_manager=checkpoint_manager)
    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=6, eval_interval=2)
    # Training must have stopped before the requested 10 steps.
    self.assertLess(test_runner.global_step, 10)

  def test_evaluate_with_loss_output(self):
    """Checks eval summaries for an evaluator that returns per-step losses."""
    test_evaluator = TestEvaluator()
    checkpoint = tf.train.Checkpoint(model=test_evaluator.model)
    checkpoint.save(os.path.join(self.model_dir, "ckpt"))
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint, self.model_dir, max_to_keep=None)
    test_controller = controller.Controller(
        evaluator=test_evaluator,
        global_step=tf.Variable(0, dtype=tf.int64),
        checkpoint_manager=checkpoint_manager,
        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
    test_controller.evaluate(steps=5)
    # Only eval summaries are written
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval")))
    self.assertNotEmpty(
        summaries_with_matching_keyword(
            "eval_loss", os.path.join(self.model_dir, "summaries/eval")))

  def test_evaluate_with_no_output(self):
    """Checks `evaluate` returns only "steps_per_second" for a no-op eval."""
    test_controller = controller.Controller(
        evaluator=TestEvaluatorNoOutput(),
        global_step=tf.Variable(0, dtype=tf.int64),
        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
    self.assertSameElements(["steps_per_second"],
                            test_controller.evaluate(steps=5).keys())

  def test_train_and_evaluate_reset_datasets(self):
    """Re-runs train_and_evaluate after replacing the runner's datasets.

    There are no assertions; the test only checks that both calls complete.
    """
    test_runner = TestRunner()
    test_controller = controller.Controller(
        trainer=test_runner,
        evaluator=test_runner,
        global_step=test_runner.global_step,
        steps_per_loop=2)
    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=2, eval_interval=6)
    train_dataset = (
        test_runner.strategy.distribute_datasets_from_function(dataset_fn))
    eval_dataset = (
        test_runner.strategy.distribute_datasets_from_function(dataset_fn))
    test_runner.train_dataset = train_dataset
    test_runner.eval_dataset = eval_dataset
    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=2, eval_interval=6)

  @parameterized.named_parameters(
      ("_sync_checkpoint_saving", False),
      ("_async_checkpoint_saving", True)
  )
  def test_eval_and_checkpoint_interval(self, enable_async_checkpoint_saving):
    """Checks checkpoint and eval counts when intervals are below loop size."""
    test_runner = TestRunner()
    checkpoint = tf.train.Checkpoint(
        model=test_runner.model, optimizer=test_runner.optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step,
        checkpoint_interval=5)
    test_controller = controller.Controller(
        trainer=test_runner,
        evaluator=test_runner,
        global_step=test_runner.global_step,
        steps_per_loop=10,
        checkpoint_manager=checkpoint_manager,
        enable_async_checkpointing=enable_async_checkpoint_saving,
        summary_dir=self.model_dir)
    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=2, eval_interval=5)
    # Expect 2 checkpoints to be saved at step: 5, 10.
    self.assertLen(
        tf.io.gfile.glob(os.path.join(self.model_dir, "ckpt-*.data*")), 2)
    # Expect evaluation is performed 2 times at step: 5, 10.
    self.assertLen(
        summaries_with_matching_keyword("eval_loss", self.model_dir), 2)

  @parameterized.named_parameters(("DefaultSummary", False),
                                  ("InjectSummary", True))
  def test_evaluate_with_nested_summaries(self, inject_summary_manager):
    """Checks nested eval outputs land in per-dataset subdirectories."""
    test_evaluator = TestEvaluatorWithNestedSummary()
    if inject_summary_manager:
      summary_manager = orbit.utils.SummaryManager(
          self.model_dir,
          tf.summary.scalar,
          global_step=tf.Variable(0, dtype=tf.int64))
    else:
      summary_manager = None
    test_controller = controller.Controller(
        evaluator=test_evaluator,
        global_step=tf.Variable(0, dtype=tf.int64),
        eval_summary_dir=self.model_dir,
        summary_manager=summary_manager)
    test_controller.evaluate(steps=5)
    # Each top-level key of the nested output gets its own subdirectory.
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "dataset")))
    self.assertNotEmpty(
        summaries_with_matching_keyword(
            "loss", os.path.join(self.model_dir, "dataset")))
    self.assertNotEmpty(
        summaries_with_matching_keyword(
            "accuracy", os.path.join(self.model_dir, "dataset")))
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "dataset2")))
    self.assertNotEmpty(
        summaries_with_matching_keyword(
            "loss", os.path.join(self.model_dir, "dataset2")))
    self.assertNotEmpty(
        summaries_with_matching_keyword(
            "accuracy", os.path.join(self.model_dir, "dataset2")))

  def test_actions(self):
    """Checks train and eval actions receive each loop's output dict."""
    test_runner = TestRunner()
    checkpoint = tf.train.Checkpoint(
        model=test_runner.model, optimizer=test_runner.optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step,
        checkpoint_interval=10)

    class OutputRecorderAction:
      """Simple `Action` that just saves the outputs passed to `__call__`."""

      def __init__(self):
        self.outputs = []

      def __call__(self, output):
        self.outputs.append(output)

    train_output_recorder = OutputRecorderAction()
    eval_output_recorder = OutputRecorderAction()
    test_controller = controller.Controller(
        trainer=test_runner,
        evaluator=test_runner,
        train_actions=[train_output_recorder],
        eval_actions=[eval_output_recorder],
        global_step=test_runner.global_step,
        steps_per_loop=2,
        summary_dir=os.path.join(self.model_dir, "summaries/train"),
        checkpoint_manager=checkpoint_manager,
        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=2, eval_interval=6)
    # 10 train steps at 2 steps per loop are expected to give 5 train outputs.
    self.assertLen(train_output_recorder.outputs, 5)
    for output in train_output_recorder.outputs:
      self.assertIn("loss", output)
      self.assertGreaterEqual(output["loss"], 0)
    self.assertLen(eval_output_recorder.outputs, 2)
    for output in eval_output_recorder.outputs:
      self.assertIn("eval_loss", output)
      self.assertGreaterEqual(output["eval_loss"], 0)

  def test_step_per_loop_callable(self):
    """Checks training reaches the target when `steps_per_loop` is callable."""
    test_runner = TestRunner()
    checkpoint = tf.train.Checkpoint(
        model=test_runner.model, optimizer=test_runner.optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step,
        checkpoint_interval=10)

    def steps_per_loop_fn(global_step):
      # Returns 2 while `global_step` is at most 4, and 4 afterwards.
      if global_step > 4:
        return 4
      return 2

    test_controller = controller.Controller(
        trainer=test_runner,
        global_step=test_runner.global_step,
        steps_per_loop=steps_per_loop_fn,
        checkpoint_manager=checkpoint_manager
    )
    test_controller.train(steps=10)
    self.assertEqual(test_runner.global_step, 10)
# Run the test suite through `tf.test.main()` when executed as a script.
if __name__ == "__main__":
  tf.test.main()