# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""An executor class for running model on TensorFlow 2.0.""" | |
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
from absl import logging | |
import tensorflow as tf, tf_keras | |
from official.legacy.detection.executor import distributed_executor as executor | |
from official.vision.utils.object_detection import visualization_utils | |


class DetectionDistributedExecutor(executor.DistributedExecutor):
  """Detection-specific custom training loop executor.

  Subclasses the DistributedExecutor and adds support for numpy-based metrics.
  """

  def __init__(self,
               predict_post_process_fn=None,
               trainable_variables_filter=None,
               **kwargs):
    super(DetectionDistributedExecutor, self).__init__(**kwargs)
    if predict_post_process_fn:
      assert callable(predict_post_process_fn)
    if trainable_variables_filter:
      assert callable(trainable_variables_filter)
    self._predict_post_process_fn = predict_post_process_fn
    self._trainable_variables_filter = trainable_variables_filter
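    # `eval_steps` tracks evaluation progress across replicas. ON_READ
    # synchronization keeps per-replica copies that are only reconciled when
    # the value is read, and ONLY_FIRST_REPLICA aggregation reports the first
    # replica's value rather than a sum over replicas.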
    self.eval_steps = tf.Variable(
        0,
        trainable=False,
        dtype=tf.int32,
        synchronization=tf.VariableSynchronization.ON_READ,
        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
        shape=[])

  def _create_replicated_step(self,
                              strategy,
                              model,
                              loss_fn,
                              optimizer,
                              metric=None):
    trainable_variables = model.trainable_variables
    if self._trainable_variables_filter:
      trainable_variables = self._trainable_variables_filter(
          trainable_variables)
      logging.info('Filter trainable variables from %d to %d',
                   len(model.trainable_variables), len(trainable_variables))
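    # Default to a no-op update so the train step remains valid when no Keras
    # metric is supplied.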
    update_state_fn = lambda labels, outputs: None
    if isinstance(metric, tf_keras.metrics.Metric):
      update_state_fn = metric.update_state
    else:
      logging.error('Detection: train metric is not an instance of '
                    'tf_keras.metrics.Metric.')

    def _replicated_step(inputs):
      """Replicated training step."""
      inputs, labels = inputs
      with tf.GradientTape() as tape:
        outputs = model(inputs, training=True)
        all_losses = loss_fn(labels, outputs)
        losses = {}
        for k, v in all_losses.items():
          losses[k] = tf.reduce_mean(v)
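        # Divide by the number of replicas so that summing the per-replica
        # gradients (which tf.distribute does when applying them) yields the
        # gradient of the mean loss over the global batch.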
        per_replica_loss = losses['total_loss'] / strategy.num_replicas_in_sync
        update_state_fn(labels, outputs)
      grads = tape.gradient(per_replica_loss, trainable_variables)
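      # Global-norm clipping (hard-coded to 1.0 here) guards against gradient
      # spikes from the detection losses.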
      clipped_grads, _ = tf.clip_by_global_norm(grads, clip_norm=1.0)
      optimizer.apply_gradients(zip(clipped_grads, trainable_variables))
      return losses

    return _replicated_step

  def _create_test_step(self, strategy, model, metric):
    """Creates a distributed test step."""

    def test_step(iterator, eval_steps):
      """Calculates evaluation metrics on distributed devices."""

      def _test_step_fn(inputs, eval_steps):
        """Replicated accuracy calculation."""
        inputs, labels = inputs
        model_outputs = model(inputs, training=False)
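        # The optional post-processing hook typically decodes raw model
        # outputs into final detections (e.g. NMS-filtered boxes) and aligns
        # the labels for metric computation.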
        # Fall back to the raw model outputs when no post-processing function
        # was provided; this also keeps the return statement below valid.
        prediction_outputs = model_outputs
        if self._predict_post_process_fn:
          labels, prediction_outputs = self._predict_post_process_fn(
              labels, model_outputs)
          num_remaining_visualizations = (
              self._params.eval.num_images_to_visualize - eval_steps)
          # If there are visualizations remaining to be done, add the next
          # batch's outputs for visualization.
          #
          # TODO(hongjunchoi): Once dynamic slicing is supported on TPU, only
          # write the correct slice of outputs to the summary file.
          if num_remaining_visualizations > 0:
            visualization_utils.visualize_images_with_bounding_boxes(
                inputs, prediction_outputs['detection_boxes'],
                self.global_train_step, self.eval_summary_writer)
        return labels, prediction_outputs
      labels, outputs = strategy.run(
          _test_step_fn, args=(
              next(iterator),
              eval_steps,
          ))
      outputs = tf.nest.map_structure(strategy.experimental_local_results,
                                      outputs)
      labels = tf.nest.map_structure(strategy.experimental_local_results,
                                     labels)
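      # Advances by the per-step eval batch size, so `eval_steps` effectively
      # counts evaluated images rather than steps (cf. the comparison against
      # `num_images_to_visualize` above).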
      eval_steps.assign_add(self._params.eval.batch_size)
      return labels, outputs

    return test_step

  def _run_evaluation(self, test_step, current_training_step, metric,
                      test_iterator):
    """Runs validation steps and aggregates metrics."""
    self.eval_steps.assign(0)
    if not test_iterator or not metric:
      logging.warning(
          'Both test_iterator (%s) and metric (%s) must not be None.',
          test_iterator, metric)
      return None
    logging.info('Running evaluation after step: %s.', current_training_step)
    while True:
      try:
        labels, outputs = test_step(test_iterator, self.eval_steps)
        if metric:
          metric.update_state(labels, outputs)
      except (StopIteration, tf.errors.OutOfRangeError):
        break

    metric_result = metric.result()
    if isinstance(metric, tf_keras.metrics.Metric):
      metric_result = tf.nest.map_structure(lambda x: x.numpy().astype(float),
                                            metric_result)
    logging.info('Step: [%d] Validation metric = %s', current_training_step,
                 metric_result)
    return metric_result
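

# A minimal usage sketch (hypothetical, not part of this module). The exact
# keyword arguments accepted by the base DistributedExecutor (e.g. `strategy`,
# `params`, `model_fn`, `loss_fn`) are an assumption here; check
# distributed_executor.py for the actual constructor signature.
#
#   executor_instance = DetectionDistributedExecutor(
#       strategy=strategy,            # a tf.distribute.Strategy
#       params=params,                # config providing eval.batch_size etc.
#       model_fn=model_fn,            # builds the detection model
#       loss_fn=loss_fn,              # returns a dict with a 'total_loss' key
#       predict_post_process_fn=post_process_fn,
#       # Hypothetical filter: train everything except the backbone.
#       trainable_variables_filter=lambda variables: [
#           v for v in variables if 'backbone' not in v.name
#       ],
#   )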