|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
r"""Creates and runs TF2 object detection models. |
|
|
|
################################## |
|
NOTE: This module has not been fully tested; please bear with us while we iron |
|
out the kinks. |
|
################################## |
|
|
|
When a TPU device is available, this binary uses TPUStrategy. Otherwise, it uses
GPUs with MirroredStrategy/MultiWorkerMirroredStrategy.
|
|
|
For local training/evaluation run: |
|
PIPELINE_CONFIG_PATH=path/to/pipeline.config |
|
MODEL_DIR=/tmp/model_outputs |
|
NUM_TRAIN_STEPS=10000 |
|
SAMPLE_1_OF_N_EVAL_EXAMPLES=1 |
|
python model_main_tf2.py -- \ |
|
--model_dir=$MODEL_DIR --num_train_steps=$NUM_TRAIN_STEPS \ |
|
--sample_1_of_n_eval_examples=$SAMPLE_1_OF_N_EVAL_EXAMPLES \ |
|
--pipeline_config_path=$PIPELINE_CONFIG_PATH \ |
|
--alsologtostderr |
|
""" |
|
from absl import flags |
|
import tensorflow.compat.v2 as tf |
|
from object_detection import model_hparams |
|
from object_detection import model_lib_v2 |
|
|
|
# Command-line flags for this binary. Help strings are assembled from adjacent
# string literals, so every fragment that is followed by another one must end
# with a trailing space.
flags.DEFINE_string(
    'pipeline_config_path', None, 'Path to pipeline config file.')
flags.DEFINE_integer('num_train_steps', None, 'Number of train steps.')
flags.DEFINE_bool(
    'eval_on_train_data', False,
    'Enable evaluating on train data (only supported in distributed '
    'training).')
flags.DEFINE_integer(
    'sample_1_of_n_eval_examples', None,
    'Will sample one of every n eval input examples, where n is provided.')
flags.DEFINE_integer(
    'sample_1_of_n_eval_on_train_examples', 5,
    'Will sample one of every n train input examples for evaluation, '
    'where n is provided. This is only used if '
    '`eval_on_train_data` is True.')
flags.DEFINE_string(
    'hparams_overrides', None,
    'Hyperparameter overrides, represented as a string containing '
    'comma-separated hparam_name=value pairs.')
flags.DEFINE_string(
    'model_dir', None,
    'Path to output model directory where event and checkpoint files will '
    'be written.')
flags.DEFINE_string(
    'checkpoint_dir', None,
    'Path to directory holding a checkpoint. If `checkpoint_dir` is '
    'provided, this binary operates in eval-only mode, writing resulting '
    'metrics to `model_dir`.')

flags.DEFINE_integer(
    'eval_timeout', 3600,
    'Number of seconds to wait for an evaluation checkpoint before exiting.')
flags.DEFINE_integer(
    'num_workers', 1,
    'When num_workers > 1, training uses MultiWorkerMirroredStrategy. '
    'When num_workers = 1 it uses MirroredStrategy.')

# Global flag registry; values are available once the flags have been parsed.
FLAGS = flags.FLAGS
|
|
|
|
|
def main(unused_argv):
    """Runs continuous evaluation or distributed training, chosen by flags.

    If `--checkpoint_dir` is set, `model_lib_v2.eval_continuously` is invoked
    on that directory. Otherwise a `tf.distribute` strategy is selected and
    `model_lib_v2.train_loop` is run inside the strategy's scope:

      * TPUStrategy when a TPU device is visible,
      * MultiWorkerMirroredStrategy when `--num_workers` > 1,
      * MirroredStrategy in every other case.

    Args:
      unused_argv: Positional command-line arguments; ignored.
    """
    # NOTE(review): these validators are registered after the flags have been
    # parsed; confirm that they are actually evaluated at this point.
    flags.mark_flag_as_required('model_dir')
    flags.mark_flag_as_required('pipeline_config_path')
    tf.config.set_soft_device_placement(True)

    if FLAGS.checkpoint_dir:
        # Eval-only mode.
        model_lib_v2.eval_continuously(
            hparams=model_hparams.create_hparams(FLAGS.hparams_overrides),
            pipeline_config_path=FLAGS.pipeline_config_path,
            model_dir=FLAGS.model_dir,
            train_steps=FLAGS.num_train_steps,
            sample_1_of_n_eval_examples=FLAGS.sample_1_of_n_eval_examples,
            sample_1_of_n_eval_on_train_examples=(
                FLAGS.sample_1_of_n_eval_on_train_examples),
            checkpoint_dir=FLAGS.checkpoint_dir,
            wait_interval=300,
            timeout=FLAGS.eval_timeout)
        return

    # No `use_tpu` flag is defined in this file, so reading `FLAGS.use_tpu`
    # would raise AttributeError. Derive the value from the same check that
    # selects TPUStrategy so the two can never disagree.
    use_tpu = bool(tf.config.get_visible_devices('TPU'))

    if use_tpu:
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.experimental.TPUStrategy(resolver)
    elif FLAGS.num_workers > 1:
        strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
    else:
        strategy = tf.compat.v2.distribute.MirroredStrategy()

    with strategy.scope():
        model_lib_v2.train_loop(
            hparams=model_hparams.create_hparams(FLAGS.hparams_overrides),
            pipeline_config_path=FLAGS.pipeline_config_path,
            model_dir=FLAGS.model_dir,
            train_steps=FLAGS.num_train_steps,
            use_tpu=use_tpu)
|
|
|
if __name__ == '__main__':
    # Flag validators are checked when the flags are parsed, so the required
    # flags are registered before the runner parses the command line; a
    # missing value is then reported at startup.
    flags.mark_flag_as_required('model_dir')
    flags.mark_flag_as_required('pipeline_config_path')
    # `tf` is the TF2 API surface (tensorflow.compat.v2), which has no
    # `tf.app`; the absl-based runner is exposed under `tf.compat.v1`.
    tf.compat.v1.app.run()
|
|