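# NOTE(review): the six lines below appear to be file-viewer page residue
# (avatar caption, commit message, commit hash, viewer links, file size) that
# was copied along with the source; they are not part of the module.
# deanna-emery's picture
# updates
# 93528c6
# raw
# history blame
# 6.27 kB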
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""An executor class for running model on TensorFlow 2.0."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import logging
import tensorflow as tf, tf_keras
from official.legacy.detection.executor import distributed_executor as executor
from official.vision.utils.object_detection import visualization_utils
class DetectionDistributedExecutor(executor.DistributedExecutor):
  """Detection-specific custom training loop executor.

  Subclasses the DistributedExecutor and adds support for numpy based metrics.
  """

  def __init__(self,
               predict_post_process_fn=None,
               trainable_variables_filter=None,
               **kwargs):
    """Initializes the executor.

    Args:
      predict_post_process_fn: Optional callable invoked during evaluation as
        `predict_post_process_fn(labels, model_outputs)`; it must return a
        `(labels, prediction_outputs)` pair. When omitted, the raw model
        outputs are used as the prediction outputs.
      trainable_variables_filter: Optional callable that receives the model's
        trainable variables and returns the variables to optimize.
      **kwargs: Forwarded unchanged to `DistributedExecutor.__init__`.
    """
    super(DetectionDistributedExecutor, self).__init__(**kwargs)
    if predict_post_process_fn:
      assert callable(predict_post_process_fn), (
          'predict_post_process_fn must be callable.')
    if trainable_variables_filter:
      assert callable(trainable_variables_filter), (
          'trainable_variables_filter must be callable.')
    self._predict_post_process_fn = predict_post_process_fn
    self._trainable_variables_filter = trainable_variables_filter
    # Scalar int32 counter: reset to 0 at the start of each evaluation run and
    # advanced by the eval batch size after every test step. It is compared
    # against `eval.num_images_to_visualize` to decide whether to visualize.
    self.eval_steps = tf.Variable(
        0,
        trainable=False,
        dtype=tf.int32,
        synchronization=tf.VariableSynchronization.ON_READ,
        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
        shape=[])

  def _create_replicated_step(self,
                              strategy,
                              model,
                              loss_fn,
                              optimizer,
                              metric=None):
    """Builds the per-replica training step function.

    Args:
      strategy: Distribution strategy; only `num_replicas_in_sync` is read, to
        scale the loss.
      model: Callable model invoked as `model(inputs, training=True)`.
      loss_fn: Callable `loss_fn(labels, outputs)` returning a dict of losses
        that contains the key 'total_loss'.
      optimizer: Optimizer whose `apply_gradients` is called with the clipped
        gradients.
      metric: Optional `tf_keras.metrics.Metric` updated with
        `(labels, outputs)` on every step. Any other value (including None) is
        ignored after an error is logged.

    Returns:
      A function taking an `(inputs, labels)` pair and returning a dict that
      maps each loss name to its mean value.
    """
    trainable_variables = model.trainable_variables
    if self._trainable_variables_filter:
      trainable_variables = self._trainable_variables_filter(
          trainable_variables)
      logging.info('Filter trainable variables from %d to %d',
                   len(model.trainable_variables), len(trainable_variables))

    # No-op by default so the step body can call it unconditionally.
    update_state_fn = lambda labels, outputs: None
    if isinstance(metric, tf_keras.metrics.Metric):
      update_state_fn = metric.update_state
    else:
      logging.error('Detection: train metric is not an instance of '
                    'tf_keras.metrics.Metric.')

    def _replicated_step(inputs):
      """Replicated training step."""
      inputs, labels = inputs

      with tf.GradientTape() as tape:
        outputs = model(inputs, training=True)
        all_losses = loss_fn(labels, outputs)
        losses = {k: tf.reduce_mean(v) for k, v in all_losses.items()}
        # Scale by the replica count so that the gradients summed across
        # replicas correspond to the mean loss.
        per_replica_loss = losses['total_loss'] / strategy.num_replicas_in_sync
        update_state_fn(labels, outputs)

      grads = tape.gradient(per_replica_loss, trainable_variables)
      # Clip to a fixed global norm of 1.0 before applying.
      clipped_grads, _ = tf.clip_by_global_norm(grads, clip_norm=1.0)
      optimizer.apply_gradients(zip(clipped_grads, trainable_variables))
      return losses

    return _replicated_step

  def _create_test_step(self, strategy, model, metric):
    """Creates a distributed test step.

    Args:
      strategy: Distribution strategy used to run the replicated step and to
        gather per-replica results.
      model: Callable model invoked as `model(inputs, training=False)`.
      metric: Unused by this method (presumably kept to match the base-class
        signature; verify against `DistributedExecutor`).

    Returns:
      A `tf.function` taking `(iterator, eval_steps)` and returning
      `(labels, outputs)`, each mapped through
      `strategy.experimental_local_results`.
    """

    @tf.function
    def test_step(iterator, eval_steps):
      """Calculates evaluation metrics on distributed devices."""

      def _test_step_fn(inputs, eval_steps):
        """Replicated accuracy calculation."""
        inputs, labels = inputs
        model_outputs = model(inputs, training=False)
        # Fall back to the raw model outputs so the return value is always
        # bound; previously this raised UnboundLocalError whenever no
        # post-process function was configured.
        prediction_outputs = model_outputs
        if self._predict_post_process_fn:
          labels, prediction_outputs = self._predict_post_process_fn(
              labels, model_outputs)
          num_remaining_visualizations = (
              self._params.eval.num_images_to_visualize - eval_steps)
          # If there are remaining number of visualizations that needs to be
          # done, add next batch outputs for visualization.
          #
          # TODO(hongjunchoi): Once dynamic slicing is supported on TPU, only
          # write correct slice of outputs to summary file.
          if num_remaining_visualizations > 0:
            visualization_utils.visualize_images_with_bounding_boxes(
                inputs, prediction_outputs['detection_boxes'],
                self.global_train_step, self.eval_summary_writer)

        return labels, prediction_outputs

      labels, outputs = strategy.run(
          _test_step_fn, args=(
              next(iterator),
              eval_steps,
          ))
      outputs = tf.nest.map_structure(strategy.experimental_local_results,
                                      outputs)
      labels = tf.nest.map_structure(strategy.experimental_local_results,
                                     labels)

      # Advance the counter consulted by the visualization check above.
      eval_steps.assign_add(self._params.eval.batch_size)
      return labels, outputs

    return test_step

  def _run_evaluation(self, test_step, current_training_step, metric,
                      test_iterator):
    """Runs validation steps and aggregate metrics.

    Args:
      test_step: Callable invoked as `test_step(test_iterator, eval_steps)`
        returning `(labels, outputs)`.
      current_training_step: Training step number, used only for logging.
      metric: Object exposing `update_state(labels, outputs)` and `result()`.
      test_iterator: Iterator over the evaluation data; evaluation stops when
        it raises `StopIteration` or `tf.errors.OutOfRangeError`.

    Returns:
      The metric result, or None when `test_iterator` or `metric` is falsy.
      For `tf_keras.metrics.Metric` instances every leaf of the result is
      converted with `.numpy().astype(float)`.
    """
    self.eval_steps.assign(0)
    if not test_iterator or not metric:
      logging.warning(
          'Both test_iterator (%s) and metrics (%s) must not be None.',
          test_iterator, metric)
      return None
    logging.info('Running evaluation after step: %s.', current_training_step)
    while True:
      try:
        labels, outputs = test_step(test_iterator, self.eval_steps)
        # `metric` is known to be truthy here because of the guard above.
        metric.update_state(labels, outputs)
      except (StopIteration, tf.errors.OutOfRangeError):
        break

    metric_result = metric.result()
    if isinstance(metric, tf_keras.metrics.Metric):
      metric_result = tf.nest.map_structure(lambda x: x.numpy().astype(float),
                                            metric_result)
    logging.info('Step: [%d] Validation metric = %s', current_training_step,
                 metric_result)
    return metric_result