|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Detection model trainer. |
|
|
|
This file provides a generic training method that can be used to train a |
|
DetectionModel. |
|
""" |
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import functools |
|
import six |
|
from six.moves import range |
|
import tensorflow.compat.v1 as tf |
|
import tf_slim as slim |
|
|
|
from object_detection.builders import optimizer_builder |
|
from object_detection.core import standard_fields as fields |
|
from object_detection.utils import ops as util_ops |
|
from object_detection.utils import variables_helper |
|
from deployment import model_deploy |
|
|
|
|
|
def create_input_queue(create_tensor_dict_fn):
  """Sets up reader, prefetcher and returns input queue.

  Flattens the per-frame tensors produced by `create_tensor_dict_fn` into a
  single dict keyed by `<field_name><frame_index>` (e.g. 'image0', 'image1'),
  keeping the un-suffixed 'batch' entry as-is.

  Args:
    create_tensor_dict_fn: function to create tensor dictionary.

  Returns:
    all_dict: A dictionary holds tensors for images, boxes, and targets.
  """
  tensor_dict = create_tensor_dict_fn()

  # 'batch' is carried through untouched; every other field is split per frame.
  all_dict = {'batch': tensor_dict.pop('batch')}

  frame_count = len(tensor_dict[fields.InputDataFields.image])
  for frame_idx in range(frame_count):
    suffix = str(frame_idx)
    for field_name, field_tensor in tensor_dict.items():
      all_dict[field_name + suffix] = field_tensor[frame_idx]

    # Images get a leading batch dimension and are converted to float.
    image_key = fields.InputDataFields.image + suffix
    all_dict[image_key] = tf.to_float(tf.expand_dims(all_dict[image_key], 0))

  return all_dict
|
|
|
|
|
def get_inputs(input_queue, num_classes, merge_multiple_label_boxes=False):
  """Dequeues batch and constructs inputs to object detection model.

  Args:
    input_queue: BatchQueue object holding enqueued tensor_dicts.
    num_classes: Number of classes.
    merge_multiple_label_boxes: Whether to merge boxes with multiple labels
      or not. Defaults to false. Merged boxes are represented with a single
      box and a k-hot encoding of the multiple labels associated with the
      boxes.

  Returns:
    images: a list of 3-D float tensor of images.
    image_keys: a list of string keys for the images.
    locations: a list of tensors of shape [num_boxes, 4] containing the corners
      of the groundtruth boxes.
    classes: a list of padded one-hot tensors containing target classes.
    masks: a list of 3-D float tensors of shape [num_boxes, image_height,
      image_width] containing instance masks for objects if present in the
      input_queue. Else returns None.
    keypoints: a list of 3-D float tensors of shape [num_boxes, num_keypoints,
      2] containing keypoints for objects if present in the
      input queue. Else returns None.
  """
  read_data_list = input_queue
  # Groundtruth class ids are 1-indexed; shift to 0-indexed before encoding.
  label_id_offset = 1

  def extract_images_and_targets(read_data):
    """Extract images and targets from the input dict."""
    suffix = 0

    images = []
    keys = []
    locations = []
    classes = []
    masks = []
    keypoints = []

    # Frames are stored under suffixed keys ('image0', 'image1', ...); walk
    # consecutive suffixes until one is missing.
    while fields.InputDataFields.image + str(suffix) in read_data:
      image = read_data[fields.InputDataFields.image + str(suffix)]
      key = ''
      # Bug fix: source ids, like every other field, are stored under
      # per-frame suffixed keys, so the presence check must use the suffixed
      # key. The old check on the bare key never matched, so `key` was
      # always '' even when source ids were available.
      if fields.InputDataFields.source_id + str(suffix) in read_data:
        key = read_data[fields.InputDataFields.source_id + str(suffix)]
      location_gt = (
          read_data[fields.InputDataFields.groundtruth_boxes + str(suffix)])
      classes_gt = tf.cast(
          read_data[fields.InputDataFields.groundtruth_classes + str(suffix)],
          tf.int32)
      classes_gt -= label_id_offset
      # Masks/keypoints are optional; .get() yields None when absent so the
      # caller can detect and drop them.
      masks_gt = read_data.get(
          fields.InputDataFields.groundtruth_instance_masks + str(suffix))
      keypoints_gt = read_data.get(
          fields.InputDataFields.groundtruth_keypoints + str(suffix))

      if merge_multiple_label_boxes:
        # Collapse duplicate boxes into one box with a k-hot class encoding.
        location_gt, classes_gt, _ = util_ops.merge_boxes_with_multiple_labels(
            location_gt, classes_gt, num_classes)
      else:
        classes_gt = util_ops.padded_one_hot_encoding(
            indices=classes_gt, depth=num_classes, left_pad=0)

      images.append(image)
      keys.append(key)
      locations.append(location_gt)
      classes.append(classes_gt)
      masks.append(masks_gt)
      keypoints.append(keypoints_gt)

      suffix += 1

    return (images, keys, locations, classes, masks, keypoints)

  return extract_images_and_targets(read_data_list)
|
|
|
|
|
def _create_losses(input_queue, create_model_fn, train_config):
  """Creates loss function for a DetectionModel.

  Builds the forward pass for one clone and registers each loss tensor in
  the tf.losses collection as a side effect; nothing is returned.

  Args:
    input_queue: BatchQueue object holding enqueued tensor_dicts.
    create_model_fn: A function to create the DetectionModel.
    train_config: a train_pb2.TrainConfig protobuf.
  """

  detection_model = create_model_fn()
  # Image keys (second element) are unused here.
  (images, _, groundtruth_boxes_list, groundtruth_classes_list,
   groundtruth_masks_list, groundtruth_keypoints_list) = get_inputs(
       input_queue, detection_model.num_classes,
       train_config.merge_multiple_label_boxes)

  # Preprocess each frame individually, then batch them along axis 0.
  preprocessed_images = []
  true_image_shapes = []
  for image in images:
    resized_image, true_image_shape = detection_model.preprocess(image)
    preprocessed_images.append(resized_image)
    true_image_shapes.append(true_image_shape)

  images = tf.concat(preprocessed_images, 0)
  true_image_shapes = tf.concat(true_image_shapes, 0)

  # Masks/keypoints are all-or-nothing: if any frame lacks them, drop them
  # entirely so provide_groundtruth sees a consistent signal.
  if any(mask is None for mask in groundtruth_masks_list):
    groundtruth_masks_list = None
  if any(keypoints is None for keypoints in groundtruth_keypoints_list):
    groundtruth_keypoints_list = None

  detection_model.provide_groundtruth(
      groundtruth_boxes_list, groundtruth_classes_list, groundtruth_masks_list,
      groundtruth_keypoints_list)
  # NOTE(review): this model's predict() takes a third argument beyond the
  # standard DetectionModel signature — presumably per-sequence batch/state
  # info from the input pipeline; confirm against the model implementation.
  prediction_dict = detection_model.predict(images, true_image_shapes,
                                            input_queue['batch'])

  # Register every component loss with tf.losses so optimize_clones can
  # aggregate them later.
  losses_dict = detection_model.loss(prediction_dict, true_image_shapes)
  for loss_tensor in losses_dict.values():
    tf.losses.add_loss(loss_tensor)
|
|
|
|
|
def get_restore_checkpoint_ops(restore_checkpoints, detection_model,
                               train_config):
  """Restore checkpoint from saved checkpoints.

  Builds one tf.train.Saver per checkpoint, restricted to the variables that
  checkpoint actually contains. A variable already claimed by an earlier
  checkpoint is skipped so each variable is restored exactly once. When the
  checkpoint holds an ExponentialMovingAverage shadow of a variable, the EMA
  value is restored in its place.

  Args:
    restore_checkpoints: loaded checkpoints.
    detection_model: Object detection model built from config file.
    train_config: a train_pb2.TrainConfig protobuf.

  Returns:
    restorers: A list ops to init the model from checkpoints.

  """
  restorers = []
  vars_restored = []
  for restore_checkpoint in restore_checkpoints:
    var_map = detection_model.restore_map(
        fine_tune_checkpoint_type=train_config.fine_tune_checkpoint_type)
    available_var_map = (
        variables_helper.get_variables_available_in_checkpoint(
            var_map, restore_checkpoint))
    # Bug fix: iterate over a snapshot of the items. Deleting from
    # available_var_map while iterating six.iteritems(available_var_map)
    # raises "RuntimeError: dictionary changed size during iteration" on
    # Python 3 whenever a variable appears in multiple checkpoints.
    for var_name, var in list(available_var_map.items()):
      if var in vars_restored:
        tf.logging.info('Variable %s contained in multiple checkpoints',
                        var.op.name)
        del available_var_map[var_name]
      else:
        vars_restored.append(var)

    # Prefer EMA shadow values when the checkpoint provides them.
    available_ema_var_map = {}
    ckpt_reader = tf.train.NewCheckpointReader(restore_checkpoint)
    ckpt_vars_to_shape_map = ckpt_reader.get_variable_to_shape_map()
    for var_name, var in available_var_map.items():
      var_name_ema = var_name + '/ExponentialMovingAverage'
      if var_name_ema in ckpt_vars_to_shape_map:
        available_ema_var_map[var_name_ema] = var
      else:
        available_ema_var_map[var_name] = var
    available_var_map = available_ema_var_map
    init_saver = tf.train.Saver(available_var_map)
    if available_var_map:
      restorers.append(init_saver)
    else:
      tf.logging.info('WARNING: Checkpoint %s has no restorable variables',
                      restore_checkpoint)

  return restorers
|
|
|
|
|
def train(create_tensor_dict_fn,
          create_model_fn,
          train_config,
          master,
          task,
          num_clones,
          worker_replicas,
          clone_on_cpu,
          ps_tasks,
          worker_job_name,
          is_chief,
          train_dir,
          graph_hook_fn=None):
  """Training function for detection models.

  Args:
    create_tensor_dict_fn: a function to create a tensor input dictionary.
    create_model_fn: a function that creates a DetectionModel and generates
      losses.
    train_config: a train_pb2.TrainConfig protobuf.
    master: BNS name of the TensorFlow master to use.
    task: The task id of this training instance.
    num_clones: The number of clones to run per machine.
    worker_replicas: The number of work replicas to train with.
    clone_on_cpu: True if clones should be forced to run on CPU.
    ps_tasks: Number of parameter server tasks.
    worker_job_name: Name of the worker job.
    is_chief: Whether this replica is the chief replica.
    train_dir: Directory to write checkpoints and training summaries to.
    graph_hook_fn: Optional function that is called after the training graph is
      completely built. This is helpful to perform additional changes to the
      training graph such as optimizing batchnorm. The function should modify
      the default graph.
  """

  # Built outside the training graph; used only for restore_map lookups in
  # get_restore_checkpoint_ops. Clones build their own model instances.
  detection_model = create_model_fn()

  with tf.Graph().as_default():

    # Replica/clone placement configuration for model_deploy.
    deploy_config = model_deploy.DeploymentConfig(
        num_clones=num_clones,
        clone_on_cpu=clone_on_cpu,
        replica_id=task,
        num_replicas=worker_replicas,
        num_ps_tasks=ps_tasks,
        worker_job_name=worker_job_name)

    # Place the global step with the variables (parameter servers).
    with tf.device(deploy_config.variables_device()):
      global_step = slim.create_global_step()

    with tf.device(deploy_config.inputs_device()):
      input_queue = create_input_queue(create_tensor_dict_fn)

    # Snapshot summaries created so far so clone summaries can be separated.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
    global_summaries = set([])

    # Each clone runs _create_losses on the shared input queue.
    model_fn = functools.partial(
        _create_losses,
        create_model_fn=create_model_fn,
        train_config=train_config)
    clones = model_deploy.create_clones(deploy_config, model_fn, [input_queue])
    first_clone_scope = clones[0].scope

    # Gather update ops (e.g. batch-norm moving averages) from the first
    # clone only; they are shared across clones.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)

    with tf.device(deploy_config.optimizer_device()):
      training_optimizer, optimizer_summary_vars = optimizer_builder.build(
          train_config.optimizer)
      for var in optimizer_summary_vars:
        tf.summary.scalar(var.op.name, var)

    # Optionally wrap the optimizer for synchronous distributed training.
    sync_optimizer = None
    if train_config.sync_replicas:
      training_optimizer = tf.train.SyncReplicasOptimizer(
          training_optimizer,
          replicas_to_aggregate=train_config.replicas_to_aggregate,
          total_num_replicas=train_config.worker_replicas)
      sync_optimizer = training_optimizer

    # Fine-tune initialization: the config may list several checkpoints
    # separated by commas; each gets its own restorer.
    init_fn = None
    if train_config.fine_tune_checkpoint:
      restore_checkpoints = [
          path.strip() for path in train_config.fine_tune_checkpoint.split(',')
      ]

      restorers = get_restore_checkpoint_ops(restore_checkpoints,
                                             detection_model, train_config)

      def initializer_fn(sess):
        # Restorers and checkpoint paths are parallel lists.
        for i, restorer in enumerate(restorers):
          restorer.restore(sess, restore_checkpoints[i])

      init_fn = initializer_fn

    with tf.device(deploy_config.optimizer_device()):
      # Passing None lets optimize_clones collect regularization losses from
      # the graph; an empty list disables them.
      regularization_losses = (
          None if train_config.add_regularization_loss else [])
      total_loss, grads_and_vars = model_deploy.optimize_clones(
          clones,
          training_optimizer,
          regularization_losses=regularization_losses)
      # Fail fast on diverging loss.
      total_loss = tf.check_numerics(total_loss, 'LossTensor is inf or nan.')

      # Optionally scale gradients of bias variables.
      if train_config.bias_grad_multiplier:
        biases_regex_list = ['.*/biases']
        grads_and_vars = variables_helper.multiply_gradients_matching_regex(
            grads_and_vars,
            biases_regex_list,
            multiplier=train_config.bias_grad_multiplier)

      # Optionally clip gradients by global norm.
      if train_config.gradient_clipping_by_norm > 0:
        with tf.name_scope('clip_grads'):
          grads_and_vars = slim.learning.clip_gradient_norms(
              grads_and_vars, train_config.gradient_clipping_by_norm)

      # Maintain EMA shadows of all model variables (decay hard-coded to
      # 0.9999); the apply op runs as part of update_ops each step.
      moving_average_variables = slim.get_model_variables()
      variable_averages = tf.train.ExponentialMovingAverage(0.9999, global_step)
      update_ops.append(variable_averages.apply(moving_average_variables))

      # The gradient application joins the update_ops barrier so the train
      # op only yields the loss after all updates have run.
      grad_updates = training_optimizer.apply_gradients(
          grads_and_vars, global_step=global_step)
      update_ops.append(grad_updates)
      update_op = tf.group(*update_ops, name='update_barrier')
      with tf.control_dependencies([update_op]):
        train_tensor = tf.identity(total_loss, name='train_op')

    # Caller-supplied hook runs after the full training graph is built.
    if graph_hook_fn:
      with tf.device(deploy_config.variables_device()):
        graph_hook_fn()

    # Summaries: per-variable histograms, per-loss scalars, and total loss.
    for model_var in slim.get_model_variables():
      global_summaries.add(tf.summary.histogram(model_var.op.name, model_var))
    for loss_tensor in tf.losses.get_losses():
      global_summaries.add(tf.summary.scalar(loss_tensor.op.name, loss_tensor))
    global_summaries.add(
        tf.summary.scalar('TotalLoss', tf.losses.get_total_loss()))

    # Keep only first-clone summaries (others are duplicates), plus any
    # summaries tagged under the 'critic_loss' collection scope.
    summaries |= set(
        tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
    summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES, 'critic_loss'))
    summaries |= global_summaries

    summary_op = tf.summary.merge(list(summaries), name='summary_op')

    # Soft placement lets ops without a kernel on the requested device fall
    # back instead of failing.
    session_config = tf.ConfigProto(
        allow_soft_placement=True, log_device_placement=False)

    keep_checkpoint_every_n_hours = train_config.keep_checkpoint_every_n_hours
    saver = tf.train.Saver(
        keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

    # Hand off to slim's training loop; num_steps == 0 means train forever.
    slim.learning.train(
        train_tensor,
        logdir=train_dir,
        master=master,
        is_chief=is_chief,
        session_config=session_config,
        startup_delay_steps=train_config.startup_delay_steps,
        init_fn=init_fn,
        summary_op=summary_op,
        number_of_steps=(train_config.num_steps
                         if train_config.num_steps else None),
        save_summaries_secs=120,
        sync_optimizer=sync_optimizer,
        saver=saver)
|
|