# Copyright 2023 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for orbit.controller.""" | |
import os | |
from absl import logging | |
from absl.testing import parameterized | |
import numpy as np | |
from orbit import controller | |
from orbit import runner | |
from orbit import standard_runner | |
import orbit.utils | |
import tensorflow as tf, tf_keras | |


def create_model():
  x = tf_keras.layers.Input(shape=(3,), name="input")
  y = tf_keras.layers.Dense(4, name="dense")(x)
  model = tf_keras.Model(x, y)
  return model


def summaries_with_matching_keyword(keyword, summary_dir):
  """Returns summary protos matching given keyword from event file."""
  matches = []
  event_paths = tf.io.gfile.glob(os.path.join(summary_dir, "events*"))
  for event in tf.compat.v1.train.summary_iterator(event_paths[-1]):
    if event.summary is not None:
      for value in event.summary.value:
        if keyword in value.tag:
          matches.append(event.summary)
  return matches


def dataset_fn(ctx):
  """Returns a dataset of constant inputs and targets."""
  del ctx  # Unused `tf.distribute.InputContext`.
  inputs = np.zeros((10, 3), dtype=np.float32)
  targets = np.ones((10, 4), dtype=np.float32)
  dataset = tf.data.Dataset.from_tensor_slices((inputs, targets))
  dataset = dataset.repeat(100)
  dataset = dataset.batch(10, drop_remainder=True)
  return dataset
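

# Note: `dataset_fn` is consumed via
# `strategy.distribute_datasets_from_function`, which calls it once per input
# pipeline with a `tf.distribute.InputContext` (unused here). The constant
# inputs and targets make the loss fall deterministically, which
# `test_early_stop_on_eval_loss` below relies on.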


class TestRunner(standard_runner.StandardTrainer,
                 standard_runner.StandardEvaluator):
  """Implements the training and evaluation APIs for the test model."""

  def __init__(self, return_numpy=False):
    self.strategy = tf.distribute.get_strategy()
    self.model = create_model()
    self.optimizer = tf_keras.optimizers.RMSprop(learning_rate=0.1)
    self.global_step = self.optimizer.iterations
    self.train_loss = tf_keras.metrics.Mean("train_loss", dtype=tf.float32)
    self.eval_loss = tf_keras.metrics.Mean("eval_loss", dtype=tf.float32)
    self.return_numpy = return_numpy
    train_dataset = self.strategy.distribute_datasets_from_function(dataset_fn)
    eval_dataset = self.strategy.distribute_datasets_from_function(dataset_fn)
    standard_runner.StandardTrainer.__init__(self, train_dataset)
    standard_runner.StandardEvaluator.__init__(self, eval_dataset)

  def train_step(self, iterator):

    def _replicated_step(inputs):
      """Replicated training step."""
      inputs, targets = inputs
      with tf.GradientTape() as tape:
        outputs = self.model(inputs)
        loss = tf.reduce_mean(tf_keras.losses.MSE(targets, outputs))
      grads = tape.gradient(loss, self.model.variables)
      self.optimizer.apply_gradients(zip(grads, self.model.variables))
      self.train_loss.update_state(loss)

    self.strategy.run(_replicated_step, args=(next(iterator),))

  def train_loop_end(self):
    train_loss = self.train_loss.result()
    return {
        "loss": train_loss.numpy() if self.return_numpy else train_loss,
    }

  def build_eval_dataset(self):
    return self.strategy.distribute_datasets_from_function(dataset_fn)

  def eval_begin(self):
    self.eval_loss.reset_states()

  def eval_step(self, iterator):

    def _replicated_step(inputs):
      """Replicated evaluation step."""
      inputs, targets = inputs
      outputs = self.model(inputs)
      loss = tf.reduce_mean(tf_keras.losses.MSE(targets, outputs))
      self.eval_loss.update_state(loss)

    self.strategy.run(_replicated_step, args=(next(iterator),))

  def eval_end(self):
    eval_loss = self.eval_loss.result()
    return {
        "eval_loss": eval_loss.numpy() if self.return_numpy else eval_loss,
    }


class TestEvaluator(standard_runner.StandardEvaluator):
  """Implements the evaluation API for the test model."""

  def __init__(self):
    self.strategy = tf.distribute.get_strategy()
    self.model = create_model()
    eval_dataset = self.strategy.distribute_datasets_from_function(dataset_fn)
    standard_runner.StandardEvaluator.__init__(self, eval_dataset)

  def eval_reduce(self, state, output):
    state.append(output)
    return state

  def eval_begin(self):
    return []

  def eval_step(self, iterator):

    def _replicated_step(inputs):
      """Replicated evaluation step."""
      inputs, targets = inputs
      outputs = self.model(inputs)
      loss = tf.reduce_mean(tf_keras.losses.MSE(targets, outputs))
      return loss

    per_replica_losses = self.strategy.run(
        _replicated_step, args=(next(iterator),))
    mean_loss = self.strategy.reduce(
        tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)
    return mean_loss

  def eval_end(self, outputs):
    return {
        "eval_loss": tf.reduce_mean(outputs),
    }


class TestEvaluatorNoOutput(runner.AbstractEvaluator):
  """An evaluator whose `evaluate` produces no outputs."""

  def evaluate(self, num_steps):
    pass


class TestEvaluatorWithNestedSummary(standard_runner.StandardEvaluator):
  """Implements the evaluation API for the test model, with nested outputs."""

  def __init__(self):
    self.strategy = tf.distribute.get_strategy()
    self.model = create_model()
    dataset = self.strategy.distribute_datasets_from_function(dataset_fn)
    dataset2 = self.strategy.distribute_datasets_from_function(dataset_fn)
    self.loss = tf_keras.metrics.Mean("loss", dtype=tf.float32)
    self.accuracy = tf_keras.metrics.CategoricalAccuracy(
        "accuracy", dtype=tf.float32)
    self.loss2 = tf_keras.metrics.Mean("loss", dtype=tf.float32)
    self.accuracy2 = tf_keras.metrics.CategoricalAccuracy(
        "accuracy", dtype=tf.float32)
    standard_runner.StandardEvaluator.__init__(
        self, eval_dataset={
            "dataset": dataset,
            "dataset2": dataset2
        })

  def eval_step(self, iterator):

    def _replicated_step(loss, accuracy, inputs):
      """Replicated evaluation step."""
      inputs, targets = inputs
      outputs = self.model(inputs)
      loss.update_state(tf_keras.losses.MSE(targets, outputs))
      accuracy.update_state(targets, outputs)

    self.strategy.run(
        lambda inputs: _replicated_step(self.loss, self.accuracy, inputs),
        args=(next(iterator["dataset"]),))
    self.strategy.run(
        lambda inputs: _replicated_step(self.loss2, self.accuracy2, inputs),
        args=(next(iterator["dataset2"]),))

  def eval_end(self):
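    # Returning a nested dict from `eval_end` makes the controller write each
    # top-level key into its own summary subdirectory ("dataset" and
    # "dataset2"); `test_evaluate_with_nested_summaries` below asserts on this.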
    return {
        "dataset": {
            "loss": self.loss.result(),
            "accuracy": self.accuracy.result()
        },
        "dataset2": {
            "loss": self.loss2.result(),
            "accuracy": self.accuracy2.result()
        },
    }


class TestTrainerWithSummaries(standard_runner.StandardTrainer):
  """A trainer that writes summaries inside its train step, for testing."""

  def __init__(self):
    self.strategy = tf.distribute.get_strategy()
    self.model = create_model()
    self.optimizer = tf_keras.optimizers.RMSprop(learning_rate=0.1)
    self.global_step = self.optimizer.iterations
    self.train_loss = tf_keras.metrics.Mean("train_loss", dtype=tf.float32)
    train_dataset = self.strategy.distribute_datasets_from_function(dataset_fn)
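    # `use_tpu_summary_optimization=True` traces two copies of the train
    # function (one with summary ops, one without), so the summary-writing
    # program runs only on steps where summaries are actually recorded; this
    # avoids slow conditional summary writes on TPU.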
    standard_runner.StandardTrainer.__init__(
        self,
        train_dataset,
        options=standard_runner.StandardTrainerOptions(
            use_tpu_summary_optimization=True))

  def build_train_dataset(self):
    return self.strategy.distribute_datasets_from_function(dataset_fn)

  def train_step(self, iterator):

    def _replicated_step(inputs):
      """Replicated training step."""
      inputs, targets = inputs
      with tf.GradientTape() as tape:
        outputs = self.model(inputs)
        loss = tf.reduce_mean(tf_keras.losses.MSE(targets, outputs))
        tf.summary.scalar("loss", loss)
      grads = tape.gradient(loss, self.model.variables)
      self.optimizer.apply_gradients(zip(grads, self.model.variables))
      self.train_loss.update_state(loss)

    self.strategy.run(_replicated_step, args=(next(iterator),))


class ControllerTest(tf.test.TestCase, parameterized.TestCase):

  def setUp(self):
    super().setUp()
    self.model_dir = self.get_temp_dir()

  def test_no_checkpoint(self):
    test_runner = TestRunner()
    # No checkpoint manager provided.
    test_controller = controller.Controller(
        trainer=test_runner,
        evaluator=test_runner,
        global_step=test_runner.global_step,
        steps_per_loop=2,
        summary_dir=os.path.join(self.model_dir, "summaries/train"),
        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=2, eval_interval=6)
    self.assertEqual(test_runner.global_step, 10)
    # Loss values should be written into train and eval summaries.
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
    self.assertNotEmpty(
        summaries_with_matching_keyword(
            "loss", os.path.join(self.model_dir, "summaries/train")))
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval")))
    self.assertNotEmpty(
        summaries_with_matching_keyword(
            "eval_loss", os.path.join(self.model_dir, "summaries/eval")))
    # With no checkpoint to restore, training starts over from the reset
    # global step.
    test_runner.global_step.assign(0)
    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=2, eval_interval=6)
    self.assertEqual(test_runner.global_step, 10)
    self.assertTrue(controller._orbit_api_gauge.get_cell().value())

  def test_no_checkpoint_and_summaries(self):
    test_runner = TestRunner()
    # Neither checkpoint manager nor summary directories provided.
    test_controller = controller.Controller(
        trainer=test_runner,
        evaluator=test_runner,
        global_step=test_runner.global_step,
        steps_per_loop=2)
    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=2, eval_interval=6)
    self.assertEqual(test_runner.global_step, 10)
    self.assertTrue(controller._orbit_api_gauge.get_cell().value())

  @parameterized.parameters(True, False)
  def test_has_checkpoint_no_summaries(self, enable_async_checkpoint_saving):
    test_runner = TestRunner()
    # Has a checkpoint manager, but no summary directories.
    checkpoint = tf.train.Checkpoint(model=test_runner.model)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step)
    test_controller = controller.Controller(
        trainer=test_runner,
        evaluator=test_runner,
        global_step=test_runner.global_step,
        checkpoint_manager=checkpoint_manager,
        enable_async_checkpointing=enable_async_checkpoint_saving,
        steps_per_loop=2)
    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=2, eval_interval=6)
    self.assertEqual(test_runner.global_step, 10)
    self.assertTrue(controller._orbit_api_gauge.get_cell().value())
    # No summaries are saved.
    self.assertEmpty(tf.io.gfile.glob(
        os.path.join(checkpoint_manager.directory, "events.*")))

  @parameterized.parameters(True, False)
  def test_has_checkpoint_eval_summary_only(
      self, enable_async_checkpoint_saving
  ):
    test_runner = TestRunner()
    # Has a checkpoint manager and eval summary dir, but no train summary dir.
    checkpoint = tf.train.Checkpoint(model=test_runner.model)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step)
    test_controller = controller.Controller(
        trainer=test_runner,
        evaluator=test_runner,
        global_step=test_runner.global_step,
        checkpoint_manager=checkpoint_manager,
        enable_async_checkpointing=enable_async_checkpoint_saving,
        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
        steps_per_loop=2)
    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=2, eval_interval=6)
    self.assertEqual(test_runner.global_step, 10)
    # Training summaries are not saved.
    self.assertEmpty(tf.io.gfile.glob(
        os.path.join(checkpoint_manager.directory, "events.*")))
    # Evaluation summaries are saved.
    self.assertNotEmpty(tf.io.gfile.glob(
        os.path.join(self.model_dir, "summaries/eval/events.*")))

  @parameterized.parameters(True, False)
  def test_restore_from_most_recent_checkpoint(
      self, enable_async_checkpoint_saving
  ):
    test_runner = TestRunner()
    checkpoint = tf.train.Checkpoint(model=test_runner.model)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step,
        checkpoint_interval=5)
    test_controller = controller.Controller(
        trainer=test_runner,
        global_step=test_runner.global_step,
        checkpoint_manager=checkpoint_manager,
        enable_async_checkpointing=enable_async_checkpoint_saving,
        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
        steps_per_loop=5)
    test_controller.train(20)
    # With `checkpoint_interval=5`, checkpoints are saved at steps 5, 10, 15,
    # and 20.
    self.assertLen(checkpoint_manager.checkpoints, 4)
    restored_path = test_controller.restore_checkpoint()
    self.assertEqual(restored_path, checkpoint_manager.checkpoints[-1])

  @parameterized.parameters((True, True), (True, False), (False, True),
                            (False, False))
  def test_train_and_evaluate(
      self, return_numpy, enable_async_checkpoint_saving
  ):
    test_runner = TestRunner(return_numpy=return_numpy)
    checkpoint = tf.train.Checkpoint(
        model=test_runner.model, optimizer=test_runner.optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step,
        checkpoint_interval=10)
    test_controller = controller.Controller(
        trainer=test_runner,
        evaluator=test_runner,
        global_step=test_runner.global_step,
        steps_per_loop=2,
        summary_dir=os.path.join(self.model_dir, "summaries/train"),
        checkpoint_manager=checkpoint_manager,
        enable_async_checkpointing=enable_async_checkpoint_saving,
        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=2, eval_interval=6)
    # Checkpoints are saved.
    self.assertNotEmpty(tf.io.gfile.glob(os.path.join(self.model_dir, "ckpt*")))
    # Loss values should be written into train and eval summaries.
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
    self.assertNotEmpty(
        summaries_with_matching_keyword(
            "loss", os.path.join(self.model_dir, "summaries/train")))
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval")))
    self.assertNotEmpty(
        summaries_with_matching_keyword(
            "eval_loss", os.path.join(self.model_dir, "summaries/eval")))

  @parameterized.parameters(True, False)
  def test_train_only(self, enable_async_checkpoint_saving):
    test_runner = TestRunner()
    checkpoint = tf.train.Checkpoint(
        model=test_runner.model, optimizer=test_runner.optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step,
        checkpoint_interval=10)
    test_controller = controller.Controller(
        trainer=test_runner,
        global_step=test_runner.global_step,
        steps_per_loop=2,
        summary_dir=os.path.join(self.model_dir, "summaries/train"),
        checkpoint_manager=checkpoint_manager,
        enable_async_checkpointing=enable_async_checkpoint_saving,
        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
    )
    test_controller.train(steps=10)
    # Checkpoints are saved.
    self.assertNotEmpty(tf.io.gfile.glob(os.path.join(self.model_dir, "ckpt*")))
    # Only train summaries are written.
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
    self.assertNotEmpty(
        summaries_with_matching_keyword(
            "loss", os.path.join(self.model_dir, "summaries/train")))
    self.assertFalse(
        tf.io.gfile.exists(os.path.join(self.model_dir, "summaries/eval")))

  def test_evaluate_only(self):
    test_runner = TestRunner()
    checkpoint = tf.train.Checkpoint(model=test_runner.model)
    checkpoint.save(os.path.join(self.model_dir, "ckpt"))
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step)
    test_controller = controller.Controller(
        evaluator=test_runner,
        global_step=test_runner.global_step,
        checkpoint_manager=checkpoint_manager,
        summary_dir=os.path.join(self.model_dir, "summaries/train"),
        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
    eval_results = test_controller.evaluate(steps=2)
    # Only eval summaries are written.
    self.assertFalse(
        tf.io.gfile.exists(os.path.join(self.model_dir, "summaries/train")))
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval")))
    self.assertNotEmpty(
        summaries_with_matching_keyword(
            "eval_loss", os.path.join(self.model_dir, "summaries/eval")))
    self.assertIn("eval_loss", eval_results)

    # Tests continuous eval with timeout and timeout_fn.
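    # `evaluate_continuously` polls the checkpoint directory for new
    # checkpoints; after `timeout` seconds without one it calls `timeout_fn`,
    # and a `True` return value stops the loop.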
    done_file = os.path.join(self.model_dir, "summaries/eval/Done")

    def timeout_fn():
      with tf.io.gfile.GFile(done_file, "w") as f:
        f.write("DONE")
      return True

    test_controller = controller.Controller(
        evaluator=test_runner,
        global_step=test_runner.global_step,
        checkpoint_manager=checkpoint_manager,
        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
    test_controller.evaluate_continuously(
        timeout=1, timeout_fn=timeout_fn, steps=2)
    self.assertNotEmpty(tf.io.gfile.glob(done_file))

  def test_no_eval_steps(self):
    test_runner = TestRunner()
    checkpoint = tf.train.Checkpoint(model=test_runner.model)
    checkpoint.save(os.path.join(self.model_dir, "ckpt"))
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step)
    test_controller = controller.Controller(
        evaluator=test_runner,
        global_step=test_runner.global_step,
        checkpoint_manager=checkpoint_manager)
    test_controller.evaluate()

  @parameterized.parameters(True, False)
  def test_already_trained_model(self, enable_async_checkpoint_saving):
    test_runner = TestRunner()
    test_runner.global_step.assign(10)
    checkpoint = tf.train.Checkpoint(
        model=test_runner.model, optimizer=test_runner.optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step,
        checkpoint_interval=10)
    test_controller = controller.Controller(
        trainer=test_runner,
        global_step=test_runner.global_step,
        steps_per_loop=2,
        checkpoint_manager=checkpoint_manager,
        enable_async_checkpointing=enable_async_checkpoint_saving)
    # `global_step` already equals `train_steps`, so training is a no-op.
    test_controller.train(steps=10)

  def test_summaries_inside_train_fn(self):
    test_runner = TestTrainerWithSummaries()
    checkpoint = tf.train.Checkpoint(
        model=test_runner.model, optimizer=test_runner.optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step)
    test_controller = controller.Controller(
        trainer=test_runner,
        global_step=test_runner.global_step,
        steps_per_loop=2,
        summary_dir=os.path.join(self.model_dir, "summaries/train"),
        summary_interval=2,
        checkpoint_manager=checkpoint_manager
    )
    test_controller.train(steps=10)
    # No checkpoints are saved, since `checkpoint_interval` was not set.
    self.assertEmpty(tf.io.gfile.glob(os.path.join(self.model_dir, "ckpt*")))
    # Only train summaries are written.
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/train")))
    self.assertNotEmpty(
        summaries_with_matching_keyword(
            "loss", os.path.join(self.model_dir, "summaries/train")))
    self.assertFalse(
        tf.io.gfile.exists(os.path.join(self.model_dir, "summaries/eval")))

  def test_train_and_evaluate_with_same_summary_dir(self):
    test_runner = TestRunner()
    checkpoint = tf.train.Checkpoint(
        model=test_runner.model, optimizer=test_runner.optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step)
    test_controller = controller.Controller(
        trainer=test_runner,
        evaluator=test_runner,
        global_step=test_runner.global_step,
        steps_per_loop=2,
        summary_dir=os.path.join(self.model_dir, "summaries"),
        checkpoint_manager=checkpoint_manager,
        eval_summary_dir=os.path.join(self.model_dir, "summaries"))
    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=2, eval_interval=6)
    # Train and eval loss values should be written into the shared summary
    # directory.
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries")))
    self.assertNotEmpty(
        summaries_with_matching_keyword(
            "loss", os.path.join(self.model_dir, "summaries")))
    self.assertNotEmpty(
        summaries_with_matching_keyword(
            "eval_loss", os.path.join(self.model_dir, "summaries")))

  def test_early_stop_on_eval_loss(self):
    test_runner = TestRunner()

    class EarlyStopController(controller.Controller):
      """A subclass of Controller that supports early stopping."""

      def train_and_evaluate(self,
                             train_steps: int = None,
                             eval_steps: int = None,
                             eval_interval: int = None):
        while self.global_step.numpy() < train_steps:
          interval = min(train_steps - self.global_step.numpy(), eval_interval)
          num_steps = self.global_step.numpy() + interval
          self.train(steps=num_steps, checkpoint_at_completion=False)
          self._sync_on_async_checkpointing()
          self.evaluate(steps=eval_steps)
          # Early stop condition.
          if test_runner.eval_loss.result() < 0.1:
            logging.info(
                "Training early stopped as eval_loss %s is less than 0.1",
                test_runner.eval_loss.result())
            return

    checkpoint = tf.train.Checkpoint(
        model=test_runner.model, optimizer=test_runner.optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step,
        checkpoint_interval=10)
    test_controller = EarlyStopController(
        trainer=test_runner,
        evaluator=test_runner,
        global_step=test_runner.global_step,
        steps_per_loop=2,
        checkpoint_manager=checkpoint_manager)
    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=6, eval_interval=2)
    self.assertLess(test_runner.global_step, 10)

  def test_evaluate_with_loss_output(self):
    test_evaluator = TestEvaluator()
    checkpoint = tf.train.Checkpoint(model=test_evaluator.model)
    checkpoint.save(os.path.join(self.model_dir, "ckpt"))
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint, self.model_dir, max_to_keep=None)
    test_controller = controller.Controller(
        evaluator=test_evaluator,
        global_step=tf.Variable(0, dtype=tf.int64),
        checkpoint_manager=checkpoint_manager,
        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
    test_controller.evaluate(steps=5)
    # Only eval summaries are written.
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "summaries/eval")))
    self.assertNotEmpty(
        summaries_with_matching_keyword(
            "eval_loss", os.path.join(self.model_dir, "summaries/eval")))

  def test_evaluate_with_no_output(self):
    test_controller = controller.Controller(
        evaluator=TestEvaluatorNoOutput(),
        global_step=tf.Variable(0, dtype=tf.int64),
        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
    self.assertSameElements(["steps_per_second"],
                            test_controller.evaluate(steps=5).keys())

  def test_train_and_evaluate_reset_datasets(self):
    test_runner = TestRunner()
    test_controller = controller.Controller(
        trainer=test_runner,
        evaluator=test_runner,
        global_step=test_runner.global_step,
        steps_per_loop=2)
    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=2, eval_interval=6)
    train_dataset = (
        test_runner.strategy.distribute_datasets_from_function(dataset_fn))
    eval_dataset = (
        test_runner.strategy.distribute_datasets_from_function(dataset_fn))
    test_runner.train_dataset = train_dataset
    test_runner.eval_dataset = eval_dataset
    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=2, eval_interval=6)

  @parameterized.parameters(True, False)
  def test_eval_and_checkpoint_interval(self, enable_async_checkpoint_saving):
    test_runner = TestRunner()
    checkpoint = tf.train.Checkpoint(
        model=test_runner.model, optimizer=test_runner.optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step,
        checkpoint_interval=5)
    test_controller = controller.Controller(
        trainer=test_runner,
        evaluator=test_runner,
        global_step=test_runner.global_step,
        steps_per_loop=10,
        checkpoint_manager=checkpoint_manager,
        enable_async_checkpointing=enable_async_checkpoint_saving,
        summary_dir=self.model_dir)
    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=2, eval_interval=5)
    # Expect 2 checkpoints to be saved, at steps 5 and 10.
    self.assertLen(
        tf.io.gfile.glob(os.path.join(self.model_dir, "ckpt-*.data*")), 2)
    # Expect evaluation to run twice, at steps 5 and 10.
    self.assertLen(
        summaries_with_matching_keyword("eval_loss", self.model_dir), 2)

  @parameterized.parameters(True, False)
  def test_evaluate_with_nested_summaries(self, inject_summary_manager):
    test_evaluator = TestEvaluatorWithNestedSummary()
    if inject_summary_manager:
      summary_manager = orbit.utils.SummaryManager(
          self.model_dir,
          tf.summary.scalar,
          global_step=tf.Variable(0, dtype=tf.int64))
    else:
      summary_manager = None
    test_controller = controller.Controller(
        evaluator=test_evaluator,
        global_step=tf.Variable(0, dtype=tf.int64),
        eval_summary_dir=self.model_dir,
        summary_manager=summary_manager)
    test_controller.evaluate(steps=5)
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "dataset")))
    self.assertNotEmpty(
        summaries_with_matching_keyword(
            "loss", os.path.join(self.model_dir, "dataset")))
    self.assertNotEmpty(
        summaries_with_matching_keyword(
            "accuracy", os.path.join(self.model_dir, "dataset")))
    self.assertNotEmpty(
        tf.io.gfile.listdir(os.path.join(self.model_dir, "dataset2")))
    self.assertNotEmpty(
        summaries_with_matching_keyword(
            "loss", os.path.join(self.model_dir, "dataset2")))
    self.assertNotEmpty(
        summaries_with_matching_keyword(
            "accuracy", os.path.join(self.model_dir, "dataset2")))

  def test_actions(self):
    test_runner = TestRunner()
    checkpoint = tf.train.Checkpoint(
        model=test_runner.model, optimizer=test_runner.optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step,
        checkpoint_interval=10)

    class OutputRecorderAction:
      """Simple `Action` that just saves the outputs passed to `__call__`."""

      def __init__(self):
        self.outputs = []

      def __call__(self, output):
        self.outputs.append(output)
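
    # Actions are plain callables: the controller invokes each train action
    # with the `train_loop_end` output dict after every inner loop (10 steps
    # with `steps_per_loop=2` gives 5 calls) and each eval action with the
    # `eval_end` output dict after each of the 2 evaluations.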
    train_output_recorder = OutputRecorderAction()
    eval_output_recorder = OutputRecorderAction()
    test_controller = controller.Controller(
        trainer=test_runner,
        evaluator=test_runner,
        train_actions=[train_output_recorder],
        eval_actions=[eval_output_recorder],
        global_step=test_runner.global_step,
        steps_per_loop=2,
        summary_dir=os.path.join(self.model_dir, "summaries/train"),
        checkpoint_manager=checkpoint_manager,
        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
    test_controller.train_and_evaluate(
        train_steps=10, eval_steps=2, eval_interval=6)
    self.assertLen(train_output_recorder.outputs, 5)
    for output in train_output_recorder.outputs:
      self.assertIn("loss", output)
      self.assertGreaterEqual(output["loss"], 0)
    self.assertLen(eval_output_recorder.outputs, 2)
    for output in eval_output_recorder.outputs:
      self.assertIn("eval_loss", output)
      self.assertGreaterEqual(output["eval_loss"], 0)

  def test_step_per_loop_callable(self):
    test_runner = TestRunner()
    checkpoint = tf.train.Checkpoint(
        model=test_runner.model, optimizer=test_runner.optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        self.model_dir,
        max_to_keep=None,
        step_counter=test_runner.global_step,
        checkpoint_interval=10)

    def steps_per_loop_fn(global_step):
      if global_step > 4:
        return 4
      return 2
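
    # A callable `steps_per_loop` is re-evaluated with the current global step
    # before each inner loop, so loop sizes here are 2, 2, 2, then 4 steps to
    # reach step 10.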
    test_controller = controller.Controller(
        trainer=test_runner,
        global_step=test_runner.global_step,
        steps_per_loop=steps_per_loop_fn,
        checkpoint_manager=checkpoint_manager
    )
    test_controller.train(steps=10)
    self.assertEqual(test_runner.global_step, 10)


if __name__ == "__main__":
  tf.test.main()