|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Tests for object_detection.legacy.trainer."""
|
import unittest |
|
import tensorflow.compat.v1 as tf |
|
import tf_slim as slim |
|
from google.protobuf import text_format |
|
|
|
from object_detection.core import losses |
|
from object_detection.core import model |
|
from object_detection.core import standard_fields as fields |
|
from object_detection.legacy import trainer |
|
from object_detection.protos import train_pb2 |
|
from object_detection.utils import tf_version |
|
|
|
|
|
# Number of object classes shared by the fake model and the generated
# groundtruth labels / multiclass scores below.
NUMBER_OF_CLASSES = 2
|
|
|
|
|
def get_input_function():
  """Builds the dictionary of input tensors used by the trainer tests.

  Each call creates new ops: a random 32x32x3 float32 image, a constant
  string key, and one groundtruth box together with its class label and its
  per-class scores.

  Returns:
    A dict keyed by `fields.InputDataFields` attributes that maps to the
    image, key, groundtruth classes, groundtruth boxes and multiclass scores
    tensors.
  """
  input_fields = fields.InputDataFields
  tensor_dict = {}

  tensor_dict[input_fields.image] = tf.random_uniform(
      [32, 32, 3], dtype=tf.float32)
  tensor_dict[input_fields.key] = tf.constant('image_000000')

  # A single int32 class label bounded by NUMBER_OF_CLASSES.
  tensor_dict[input_fields.groundtruth_classes] = tf.random_uniform(
      [1], minval=0, maxval=NUMBER_OF_CLASSES, dtype=tf.int32)

  # A single box whose four coordinates are sampled between 0.4 and 0.6.
  tensor_dict[input_fields.groundtruth_boxes] = tf.random_uniform(
      [1, 4], minval=0.4, maxval=0.6, dtype=tf.float32)

  # One score per class for that single box.
  tensor_dict[input_fields.multiclass_scores] = tf.random_uniform(
      [1, NUMBER_OF_CLASSES], minval=0.4, maxval=0.6, dtype=tf.float32)

  return tensor_dict
|
|
|
|
|
class FakeDetectionModel(model.DetectionModel):
  """A simple (and poor) DetectionModel for use in test.

  The model flattens the resized image and feeds it to two fully connected
  layers, producing a single class prediction and a single box encoding per
  image.
  """

  def __init__(self):
    """Creates the model for NUMBER_OF_CLASSES classes and its two losses."""
    super(FakeDetectionModel, self).__init__(num_classes=NUMBER_OF_CLASSES)
    self._classification_loss = losses.WeightedSigmoidClassificationLoss()
    self._localization_loss = losses.WeightedSmoothL1LocalizationLoss()

  def preprocess(self, inputs):
    """Input preprocessing, resizes images to 28x28.

    Args:
      inputs: a [batch, height_in, width_in, channels] float32 tensor
        representing a batch of images with values between 0 and 255.0.

    Returns:
      preprocessed_inputs: a [batch, 28, 28, channels] float32 tensor.
      true_image_shapes: a Python list with one entry per element of the
        last dimension of `inputs`; every entry is the static shape of
        `inputs` without its last dimension.
    """
    # NOTE(review): the DetectionModel interface describes this value as a
    # [batch, 3] tensor of [height, width, channels] rows, which is not what
    # is built here. `predict` and `loss` below never read it, so the fake
    # model keeps the original behavior -- confirm before reusing elsewhere.
    true_image_shapes = [inputs.shape[:-1].as_list()
                         for _ in range(inputs.shape[-1])]
    return tf.image.resize_images(inputs, [28, 28]), true_image_shapes

  def predict(self, preprocessed_inputs, true_image_shapes):
    """Prediction tensors from inputs tensor.

    Args:
      preprocessed_inputs: a [batch, 28, 28, channels] float32 tensor.
      true_image_shapes: the second value returned by `preprocess`. Unused.

    Returns:
      prediction_dict: a dictionary holding prediction tensors to be
        passed to the Loss or Postprocess functions. It contains
        'class_predictions_with_background', reshaped to
        [-1, 1, num_classes], and 'box_encodings', reshaped to [-1, 1, 4].
    """
    flattened_inputs = slim.flatten(preprocessed_inputs)
    class_prediction = slim.fully_connected(flattened_inputs, self._num_classes)
    box_prediction = slim.fully_connected(flattened_inputs, 4)

    return {
        'class_predictions_with_background': tf.reshape(
            class_prediction, [-1, 1, self._num_classes]),
        'box_encodings': tf.reshape(box_prediction, [-1, 1, 4])
    }

  def postprocess(self, prediction_dict, true_image_shapes, **params):
    """Convert predicted output tensors to final detections. Unused.

    Args:
      prediction_dict: a dictionary holding prediction tensors. Unused.
      true_image_shapes: the second value returned by `preprocess`. Unused.
      **params: Additional keyword arguments for specific implementations of
        DetectionModel. Unused.

    Returns:
      detections: a dictionary with empty fields, i.e. every value is None.
    """
    return {
        'detection_boxes': None,
        'detection_scores': None,
        'detection_classes': None,
        'num_detections': None
    }

  def loss(self, prediction_dict, true_image_shapes):
    """Compute scalar loss tensors with respect to provided groundtruth.

    Calling this function requires that groundtruth tensors have been
    provided via the provide_groundtruth function.

    Args:
      prediction_dict: a dictionary holding predicted tensors; the
        'box_encodings' and 'class_predictions_with_background' entries are
        read.
      true_image_shapes: the second value returned by `preprocess`. Unused.

    Returns:
      a dictionary mapping strings (loss names) to scalar tensors representing
        loss values: 'localization_loss' and 'classification_loss'.
    """
    batch_reg_targets = tf.stack(
        self.groundtruth_lists(fields.BoxListFields.boxes))
    batch_cls_targets = tf.stack(
        self.groundtruth_lists(fields.BoxListFields.classes))
    # One unit weight per entry of the groundtruth boxes list.
    weights = tf.constant(
        1.0, dtype=tf.float32,
        shape=[len(self.groundtruth_lists(fields.BoxListFields.boxes)), 1])

    location_losses = self._localization_loss(
        prediction_dict['box_encodings'], batch_reg_targets,
        weights=weights)
    cls_losses = self._classification_loss(
        prediction_dict['class_predictions_with_background'], batch_cls_targets,
        weights=weights)

    loss_dict = {
        'localization_loss': tf.reduce_sum(location_losses),
        'classification_loss': tf.reduce_sum(cls_losses),
    }
    return loss_dict

  def regularization_losses(self):
    """Returns a list of regularization losses for this model.

    Returns a list of regularization losses for this model that the estimator
    needs to use during training/optimization.

    Returns:
      A list of regularization loss tensors. The fake model exposes none, so
      the list is always empty.
    """
    # Return an empty list instead of an implicit None so that the value
    # matches the documented return type and can be iterated by callers.
    return []

  def restore_map(self, fine_tune_checkpoint_type='detection'):
    """Returns a map of variables to load from a foreign checkpoint.

    Args:
      fine_tune_checkpoint_type: whether to restore from a full detection
        checkpoint (with compatible variable names) or to restore from a
        classification checkpoint for initialization prior to training.
        Valid values: `detection`, `classification`. Default 'detection'.
        This fake model ignores the value.

    Returns:
      A dict mapping variable names to variables, covering every global
      variable of the current graph.
    """
    return {var.op.name: var for var in tf.global_variables()}

  def updates(self):
    """Returns a list of update operators for this model.

    Returns a list of update operators for this model that must be executed at
    each training step. The estimator's train op needs to have a control
    dependency on these updates.

    Returns:
      A list of update operators. The fake model exposes none, so the list is
      always empty.
    """
    # Return an empty list instead of an implicit None so that the value
    # matches the documented return type and can be iterated by callers.
    return []
|
|
|
|
|
@unittest.skipIf(tf_version.is_tf2(), 'Skipping TF1.X only test.')
class TrainerTest(tf.test.TestCase):
  """Runs the legacy trainer for two steps on the fake model and inputs.

  The tests make no assertions on the training result; they pass when
  `trainer.train` completes without raising.
  """

  # Text-format TrainConfig shared by all tests: Adam with a constant
  # learning rate, two data augmentation options and two training steps.
  _BASE_TRAIN_CONFIG_TEXT_PROTO = """
    optimizer {
      adam_optimizer {
        learning_rate {
          constant_learning_rate {
            learning_rate: 0.01
          }
        }
      }
    }
    data_augmentation_options {
      random_adjust_brightness {
        max_delta: 0.2
      }
    }
    data_augmentation_options {
      random_adjust_contrast {
        min_delta: 0.7
        max_delta: 1.1
      }
    }
    num_steps: 2
  """

  def _train_two_steps(self, extra_config_text_proto=''):
    """Parses the shared config plus extra fields and runs the trainer.

    Args:
      extra_config_text_proto: text-format TrainConfig fields appended to the
        shared base configuration before parsing.
    """
    train_config = train_pb2.TrainConfig()
    text_format.Merge(
        self._BASE_TRAIN_CONFIG_TEXT_PROTO + extra_config_text_proto,
        train_config)

    train_dir = self.get_temp_dir()

    trainer.train(
        create_tensor_dict_fn=get_input_function,
        create_model_fn=FakeDetectionModel,
        train_config=train_config,
        master='',
        task=0,
        num_clones=1,
        worker_replicas=1,
        clone_on_cpu=True,
        ps_tasks=0,
        worker_job_name='worker',
        is_chief=True,
        train_dir=train_dir)

  def test_configure_trainer_and_train_two_steps(self):
    """Trains two steps with the shared base configuration."""
    self._train_two_steps()

  def test_configure_trainer_with_multiclass_scores_and_train_two_steps(self):
    """Trains two steps with `use_multiclass_scores` enabled."""
    self._train_two_steps('use_multiclass_scores: true')
|
|
|
|
|
# Run the tests through TensorFlow's test runner when executed as a script.
if __name__ == '__main__':
  tf.test.main()
|
|