r"""Training executable for detection models. |
|
|
|
This executable is used to train DetectionModels. There are two ways of |
|
configuring the training job: |
|
|
|
1) A single pipeline_pb2.TrainEvalPipelineConfig configuration file |
|
can be specified by --pipeline_config_path. |
|
|
|
Example usage: |
|
./train \ |
|
--logtostderr \ |
|
--train_dir=path/to/train_dir \ |
|
--pipeline_config_path=pipeline_config.pbtxt |
|
|
|
2) Three configuration files can be provided: a model_pb2.DetectionModel |
|
configuration file to define what type of DetectionModel is being trained, an |
|
input_reader_pb2.InputReader file to specify what training data will be used and |
|
a train_pb2.TrainConfig file to configure training parameters. |
|
|
|
Example usage: |
|
./train \ |
|
--logtostderr \ |
|
--train_dir=path/to/train_dir \ |
|
--model_config_path=model_config.pbtxt \ |
|
--train_config_path=train_config.pbtxt \ |
|
--input_config_path=train_input_config.pbtxt |
|
|
|
""" |
|
|
|
import functools
import json
import os

from absl import flags
import tensorflow.compat.v1 as tf

from lstm_object_detection import model_builder
from lstm_object_detection import trainer
from lstm_object_detection.inputs import seq_dataset_builder
from lstm_object_detection.utils import config_util
from object_detection.builders import preprocessor_builder

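# NOTE: main() derives the cluster configuration (master, worker replicas,
# ps tasks) from the TF_CONFIG environment variable rather than from the
# corresponding flags below.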
flags.DEFINE_string('master', '', 'Name of the TensorFlow master to use.')
flags.DEFINE_integer('task', 0, 'Task id of the replica running the training.')
flags.DEFINE_integer('num_clones', 1, 'Number of clones to deploy per worker.')
flags.DEFINE_boolean(
    'clone_on_cpu', False,
    'Force clones to be deployed on CPU. Note that even if '
    'set to False (allowing ops to run on the GPU), some ops may '
    'still be run on the CPU if they have no GPU kernel.')
flags.DEFINE_integer('worker_replicas', 1,
                     'Number of worker+trainer replicas.')
flags.DEFINE_integer(
    'ps_tasks', 0, 'Number of parameter server tasks. If 0, no parameter '
    'server is used.')
flags.DEFINE_string(
    'train_dir', '',
    'Directory to save the checkpoints and training summaries.')

flags.DEFINE_string(
    'pipeline_config_path', '',
    'Path to a pipeline_pb2.TrainEvalPipelineConfig config '
    'file. If provided, the other config path flags are ignored.')

flags.DEFINE_string('train_config_path', '',
                    'Path to a train_pb2.TrainConfig config file.')
flags.DEFINE_string('input_config_path', '',
                    'Path to an input_reader_pb2.InputReader config file.')
flags.DEFINE_string('model_config_path', '',
                    'Path to a model_pb2.DetectionModel config file.')

FLAGS = flags.FLAGS


def main(_):
  assert FLAGS.train_dir, '`train_dir` is missing.'
  if FLAGS.task == 0:
    tf.gfile.MakeDirs(FLAGS.train_dir)
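  # Configs may come from a single pipeline file or from three separate
  # files. In either case the chief replica (task 0) copies them into
  # train_dir so the run is reproducible.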
  if FLAGS.pipeline_config_path:
    configs = config_util.get_configs_from_pipeline_file(
        FLAGS.pipeline_config_path)
    if FLAGS.task == 0:
      tf.gfile.Copy(
          FLAGS.pipeline_config_path,
          os.path.join(FLAGS.train_dir, 'pipeline.config'),
          overwrite=True)
  else:
    configs = config_util.get_configs_from_multiple_files(
        model_config_path=FLAGS.model_config_path,
        train_config_path=FLAGS.train_config_path,
        train_input_config_path=FLAGS.input_config_path)
    if FLAGS.task == 0:
      for name, config in [('model.config', FLAGS.model_config_path),
                           ('train.config', FLAGS.train_config_path),
                           ('input.config', FLAGS.input_config_path)]:
        tf.gfile.Copy(
            config, os.path.join(FLAGS.train_dir, name), overwrite=True)

  model_config = configs['model']
  lstm_config = configs['lstm_model']
  train_config = configs['train_config']
  input_config = configs['train_input_config']
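
  # Defer model construction via functools.partial so the trainer can build
  # the LSTM DetectionModel inside its own graph, in training mode.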
  model_fn = functools.partial(
      model_builder.build,
      model_config=model_config,
      lstm_config=lstm_config,
      is_training=True)
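
  # Builds the sequence dataset for one input config. Note that the closure
  # also reads train_config for the augmentation options and the batch size.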
  def get_next(config, model_config, lstm_config, unroll_length):
    data_augmentation_options = [
        preprocessor_builder.build(step)
        for step in train_config.data_augmentation_options
    ]
    return seq_dataset_builder.build(
        config,
        model_config,
        lstm_config,
        unroll_length,
        data_augmentation_options,
        batch_size=train_config.batch_size)
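
  # Input function handed to the trainer; the LSTM is unrolled over
  # lstm_config.train_unroll_length frames per sequence.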
  create_input_dict_fn = functools.partial(get_next, input_config,
                                           model_config, lstm_config,
                                           lstm_config.train_unroll_length)
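
  # Cluster topology comes from the standard TF_CONFIG environment variable,
  # e.g.:
  #   {"cluster": {"master": ["host0:2222"], "worker": ["host1:2222"],
  #                "ps": ["host2:2222"]},
  #    "task": {"type": "worker", "index": 0}}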
  env = json.loads(os.environ.get('TF_CONFIG', '{}'))
  cluster_data = env.get('cluster', None)
  cluster = tf.train.ClusterSpec(cluster_data) if cluster_data else None
  task_data = env.get('task', None) or {'type': 'master', 'index': 0}
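  # Expose the task dict entries as attributes: task_info.type and
  # task_info.index.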
  task_info = type('TaskSpec', (object,), task_data)
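
  # Parameters for a single (non-distributed) worker; overridden below when
  # a cluster is defined.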
  ps_tasks = 0
  worker_replicas = 1
  worker_job_name = 'lonely_worker'
  task = 0
  is_chief = True
  master = ''

  if cluster_data and 'worker' in cluster_data:
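    # The master also acts as a worker, so the total worker count is
    # len(workers) + 1.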
    worker_replicas = len(cluster_data['worker']) + 1
  if cluster_data and 'ps' in cluster_data:
    ps_tasks = len(cluster_data['ps'])

  if worker_replicas > 1 and ps_tasks < 1:
    raise ValueError('At least 1 ps task is needed for distributed training.')

  if worker_replicas >= 1 and ps_tasks > 0:
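    # Set up distributed training: start a gRPC server for this task, park
    # parameter servers in server.join(), and let master/worker tasks fall
    # through to the training loop.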
    server = tf.train.Server(
        cluster,
        protocol='grpc',
        job_name=task_info.type,
        task_index=task_info.index)
    if task_info.type == 'ps':
      server.join()
      return

    worker_job_name = '%s/task:%d' % (task_info.type, task_info.index)
    task = task_info.index
    is_chief = (task_info.type == 'master')
    master = server.target
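
  # Run the training loop (single-machine or distributed, per the values
  # computed above).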
  trainer.train(create_input_dict_fn, model_fn, train_config, master, task,
                FLAGS.num_clones, worker_replicas, FLAGS.clone_on_cpu, ps_tasks,
                worker_job_name, is_chief, FLAGS.train_dir)


if __name__ == '__main__':
  tf.app.run()