|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Tests for object detection model library.""" |
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import os |
|
import tempfile |
|
import unittest |
|
import numpy as np |
|
import six |
|
import tensorflow.compat.v1 as tf |
|
|
|
from object_detection import inputs |
|
from object_detection import model_hparams |
|
from object_detection import model_lib_v2 |
|
from object_detection.builders import model_builder |
|
from object_detection.core import model |
|
from object_detection.protos import train_pb2 |
|
from object_detection.utils import config_util |
|
from object_detection.utils import tf_version |
|
|
|
if six.PY2: |
|
import mock |
|
else: |
|
from unittest import mock |
|
|
|
|
|
|
|
MODEL_NAME_FOR_TEST = 'ssd_mobilenet_v2_pets_keras' |
|
|
|
|
|
def _get_data_path(): |
|
"""Returns an absolute path to TFRecord file.""" |
|
return os.path.join(tf.resource_loader.get_data_files_path(), 'test_data', |
|
'pets_examples.record') |
|
|
|
|
|
def get_pipeline_config_path(model_name): |
|
"""Returns path to the local pipeline config file.""" |
|
return os.path.join(tf.resource_loader.get_data_files_path(), 'samples', |
|
'configs', model_name + '.config') |
|
|
|
|
|
def _get_labelmap_path(): |
|
"""Returns an absolute path to label map file.""" |
|
return os.path.join(tf.resource_loader.get_data_files_path(), 'data', |
|
'pet_label_map.pbtxt') |
|
|
|
|
|
def _get_config_kwarg_overrides(): |
|
"""Returns overrides to the configs that insert the correct local paths.""" |
|
data_path = _get_data_path() |
|
label_map_path = _get_labelmap_path() |
|
return { |
|
'train_input_path': data_path, |
|
'eval_input_path': data_path, |
|
'label_map_path': label_map_path |
|
} |
|
|
|
|
|
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.') |
|
class ModelLibTest(tf.test.TestCase): |
|
|
|
@classmethod |
|
def setUpClass(cls): |
|
tf.keras.backend.clear_session() |
|
|
|
def test_train_loop_then_eval_loop(self): |
|
"""Tests that Estimator and input function are constructed correctly.""" |
|
hparams = model_hparams.create_hparams( |
|
hparams_overrides='load_pretrained=false') |
|
pipeline_config_path = get_pipeline_config_path(MODEL_NAME_FOR_TEST) |
|
config_kwarg_overrides = _get_config_kwarg_overrides() |
|
model_dir = tf.test.get_temp_dir() |
|
|
|
train_steps = 2 |
|
model_lib_v2.train_loop( |
|
hparams, |
|
pipeline_config_path, |
|
model_dir=model_dir, |
|
train_steps=train_steps, |
|
checkpoint_every_n=1, |
|
**config_kwarg_overrides) |
|
|
|
model_lib_v2.eval_continuously( |
|
hparams, |
|
pipeline_config_path, |
|
model_dir=model_dir, |
|
checkpoint_dir=model_dir, |
|
train_steps=train_steps, |
|
wait_interval=1, |
|
timeout=10, |
|
**config_kwarg_overrides) |
|
|
|
|
|
class SimpleModel(model.DetectionModel): |
|
"""A model with a single weight vector.""" |
|
|
|
def __init__(self, num_classes=1): |
|
super(SimpleModel, self).__init__(num_classes) |
|
self.weight = tf.keras.backend.variable(np.ones(10), name='weight') |
|
|
|
def postprocess(self, prediction_dict, true_image_shapes): |
|
return {} |
|
|
|
def updates(self): |
|
return [] |
|
|
|
def restore_map(self, *args, **kwargs): |
|
return {'model': self} |
|
|
|
def preprocess(self, _): |
|
return tf.zeros((1, 128, 128, 3)), tf.constant([[128, 128, 3]]) |
|
|
|
def provide_groundtruth(self, *args, **kwargs): |
|
pass |
|
|
|
def predict(self, pred_inputs, true_image_shapes): |
|
return {'prediction': |
|
tf.abs(tf.reduce_sum(self.weight) * tf.reduce_sum(pred_inputs))} |
|
|
|
def loss(self, prediction_dict, _): |
|
return {'loss': tf.reduce_sum(prediction_dict['prediction'])} |
|
|
|
def regularization_losses(self): |
|
return [] |
|
|
|
|
|
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.') |
|
class ModelCheckpointTest(tf.test.TestCase): |
|
"""Test for model checkpoint related functionality.""" |
|
|
|
def test_checkpoint_max_to_keep(self): |
|
"""Test that only the most recent checkpoints are kept.""" |
|
|
|
with mock.patch.object( |
|
model_builder, 'build', autospec=True) as mock_builder: |
|
mock_builder.return_value = SimpleModel() |
|
|
|
hparams = model_hparams.create_hparams( |
|
hparams_overrides='load_pretrained=false') |
|
pipeline_config_path = get_pipeline_config_path(MODEL_NAME_FOR_TEST) |
|
config_kwarg_overrides = _get_config_kwarg_overrides() |
|
model_dir = tempfile.mkdtemp(dir=self.get_temp_dir()) |
|
|
|
model_lib_v2.train_loop( |
|
hparams, pipeline_config_path, model_dir=model_dir, |
|
train_steps=20, checkpoint_every_n=2, checkpoint_max_to_keep=3, |
|
**config_kwarg_overrides |
|
) |
|
ckpt_files = tf.io.gfile.glob(os.path.join(model_dir, 'ckpt-*.index')) |
|
self.assertEqual(len(ckpt_files), 3, |
|
'{} not of length 3.'.format(ckpt_files)) |
|
|
|
|
|
class IncompatibleModel(SimpleModel): |
|
|
|
def restore_map(self, *args, **kwargs): |
|
return {'weight': self.weight} |
|
|
|
|
|
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.') |
|
class CheckpointV2Test(tf.test.TestCase): |
|
|
|
def setUp(self): |
|
super(CheckpointV2Test, self).setUp() |
|
|
|
self._model = SimpleModel() |
|
tf.keras.backend.set_value(self._model.weight, np.ones(10) * 42) |
|
ckpt = tf.train.Checkpoint(model=self._model) |
|
|
|
self._test_dir = tf.test.get_temp_dir() |
|
self._ckpt_path = ckpt.save(os.path.join(self._test_dir, 'ckpt')) |
|
tf.keras.backend.set_value(self._model.weight, np.ones(10)) |
|
|
|
pipeline_config_path = get_pipeline_config_path(MODEL_NAME_FOR_TEST) |
|
configs = config_util.get_configs_from_pipeline_file(pipeline_config_path) |
|
configs = config_util.merge_external_params_with_configs( |
|
configs, kwargs_dict=_get_config_kwarg_overrides()) |
|
self._train_input_fn = inputs.create_train_input_fn( |
|
configs['train_config'], |
|
configs['train_input_config'], |
|
configs['model']) |
|
|
|
def test_restore_v2(self): |
|
"""Test that restoring a v2 style checkpoint works.""" |
|
|
|
model_lib_v2.load_fine_tune_checkpoint( |
|
self._model, self._ckpt_path, checkpoint_type='', |
|
checkpoint_version=train_pb2.CheckpointVersion.V2, |
|
load_all_detection_checkpoint_vars=True, |
|
input_dataset=self._train_input_fn(), |
|
unpad_groundtruth_tensors=True) |
|
np.testing.assert_allclose(self._model.weight.numpy(), 42) |
|
|
|
def test_restore_map_incompatible_error(self): |
|
"""Test that restoring an incompatible restore map causes an error.""" |
|
|
|
with self.assertRaisesRegex(TypeError, |
|
r'.*received a \(str -> ResourceVariable\).*'): |
|
model_lib_v2.load_fine_tune_checkpoint( |
|
IncompatibleModel(), self._ckpt_path, checkpoint_type='', |
|
checkpoint_version=train_pb2.CheckpointVersion.V2, |
|
load_all_detection_checkpoint_vars=True, |
|
input_dataset=self._train_input_fn(), |
|
unpad_groundtruth_tensors=True) |
|
|
|
|
|
|