"""Script to train the Attention OCR model. |
|
|
|
A simple usage example: |
|
python train.py |
|
""" |
import collections
import logging

import tensorflow as tf
from tensorflow import app
from tensorflow.contrib import slim
from tensorflow.contrib.tfprof import model_analyzer
from tensorflow.python.platform import flags

import data_provider
import common_flags

FLAGS = flags.FLAGS
common_flags.define()
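
# A fuller example invocation, combining flags defined below with
# train_log_dir from common_flags (the directory path is hypothetical):
#   python train.py --train_log_dir=/tmp/attention_ocr/train \
#     --max_number_of_steps=100000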
flags.DEFINE_integer('task', 0,
                     'The Task ID. This value is used when training with '
                     'multiple workers to identify each worker.')

flags.DEFINE_integer('ps_tasks', 0,
                     'The number of parameter servers. If the value is 0, '
                     'then the parameters are handled locally by the worker.')

flags.DEFINE_integer('save_summaries_secs', 60,
                     'The frequency with which summaries are saved, in '
                     'seconds.')

flags.DEFINE_integer('save_interval_secs', 600,
                     'Frequency in seconds of saving the model.')

flags.DEFINE_integer('max_number_of_steps', int(1e10),
                     'The maximum number of gradient steps.')

flags.DEFINE_string('checkpoint_inception', '',
                    'Checkpoint to recover inception weights from.')

flags.DEFINE_float('clip_gradient_norm', 2.0,
                   'If greater than 0, gradients are clipped to this norm.')

flags.DEFINE_bool('sync_replicas', False,
                  'If True, synchronize replicas during training.')

flags.DEFINE_integer('replicas_to_aggregate', 1,
                     'The number of gradient updates to aggregate before '
                     'updating the parameters.')

flags.DEFINE_integer('total_num_replicas', 1,
                     'Total number of worker replicas.')

flags.DEFINE_integer('startup_delay_steps', 15,
                     'Number of training steps between replica startups.')

flags.DEFINE_boolean('reset_train_dir', False,
                     'If true, delete all files in the train_log_dir.')

flags.DEFINE_boolean('show_graph_stats', False,
                     'Output model size stats to stderr.')
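
# Bundles the subset of hyperparameters the training loop needs; the
# values are read from command-line flags in get_training_hparams() below.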
TrainingHParams = collections.namedtuple('TrainingHParams', [
    'learning_rate',
    'optimizer',
    'momentum',
    'use_augment_input',
])


def get_training_hparams():
  return TrainingHParams(
      learning_rate=FLAGS.learning_rate,
      optimizer=FLAGS.optimizer,
      momentum=FLAGS.momentum,
      use_augment_input=FLAGS.use_augment_input)


def create_optimizer(hparams):
  """Creates an optimizer based on the given hyperparameters."""
  if hparams.optimizer == 'momentum':
    optimizer = tf.train.MomentumOptimizer(
        hparams.learning_rate, momentum=hparams.momentum)
  elif hparams.optimizer == 'adam':
    optimizer = tf.train.AdamOptimizer(hparams.learning_rate)
  elif hparams.optimizer == 'adadelta':
    optimizer = tf.train.AdadeltaOptimizer(hparams.learning_rate)
  elif hparams.optimizer == 'adagrad':
    optimizer = tf.train.AdagradOptimizer(hparams.learning_rate)
  elif hparams.optimizer == 'rmsprop':
    optimizer = tf.train.RMSPropOptimizer(
        hparams.learning_rate, momentum=hparams.momentum)
  else:
    # Fail fast instead of raising UnboundLocalError on the return below.
    raise ValueError('Unknown optimizer: %s' % hparams.optimizer)
  return optimizer
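

# Example (hypothetical values): create_optimizer consumes the
# TrainingHParams tuple defined above, e.g.
#   hparams = TrainingHParams(learning_rate=0.004, optimizer='momentum',
#                             momentum=0.9, use_augment_input=True)
#   opt = create_optimizer(hparams)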


def train(loss, init_fn, hparams):
  """Wraps slim.learning.train to run a training loop.

  Args:
    loss: a loss tensor.
    init_fn: a callable to be executed after all other initialization is done.
    hparams: a TrainingHParams tuple with model hyperparameters.
  """
  optimizer = create_optimizer(hparams)

  if FLAGS.sync_replicas:
    # Aggregate gradients from all replicas before applying an update.
    replica_id = tf.constant(FLAGS.task, tf.int32, shape=())
    optimizer = tf.LegacySyncReplicasOptimizer(
        opt=optimizer,
        replicas_to_aggregate=FLAGS.replicas_to_aggregate,
        replica_id=replica_id,
        total_num_replicas=FLAGS.total_num_replicas)
    sync_optimizer = optimizer
    startup_delay_steps = 0
  else:
    # Note: both branches disable the delay, so FLAGS.startup_delay_steps
    # is currently unused.
    startup_delay_steps = 0
    sync_optimizer = None

  train_op = slim.learning.create_train_op(
      loss,
      optimizer,
      summarize_gradients=True,
      clip_gradient_norm=FLAGS.clip_gradient_norm)

  slim.learning.train(
      train_op=train_op,
      logdir=FLAGS.train_log_dir,
      graph=loss.graph,
      master=FLAGS.master,
      is_chief=(FLAGS.task == 0),
      number_of_steps=FLAGS.max_number_of_steps,
      save_summaries_secs=FLAGS.save_summaries_secs,
      save_interval_secs=FLAGS.save_interval_secs,
      startup_delay_steps=startup_delay_steps,
      sync_optimizer=sync_optimizer,
      init_fn=init_fn)
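

# Sketch of launching a synchronous multi-worker run with hypothetical
# flag values; each worker process would pass its own --task id:
#   python train.py --sync_replicas=true --total_num_replicas=2 \
#     --replicas_to_aggregate=2 --task=0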


def prepare_training_dir():
  if not tf.gfile.Exists(FLAGS.train_log_dir):
    logging.info('Create a new training directory %s', FLAGS.train_log_dir)
    tf.gfile.MakeDirs(FLAGS.train_log_dir)
  else:
    if FLAGS.reset_train_dir:
      logging.info('Reset the training directory %s', FLAGS.train_log_dir)
      tf.gfile.DeleteRecursively(FLAGS.train_log_dir)
      tf.gfile.MakeDirs(FLAGS.train_log_dir)
    else:
      logging.info('Use already existing training directory %s',
                   FLAGS.train_log_dir)
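

# To restart training from scratch in an existing directory, pass
# --reset_train_dir=true; note that this deletes everything under
# FLAGS.train_log_dir.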


def calculate_graph_metrics():
  param_stats = model_analyzer.print_model_analysis(
      tf.get_default_graph(),
      tfprof_options=model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS)
  return param_stats.total_parameters
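

# Note: calculate_graph_metrics() has to run after the graph has been
# built; main() below calls it (guarded by --show_graph_stats) once the
# model has been created.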


def main(_):
  prepare_training_dir()

  dataset = common_flags.create_dataset(split_name=FLAGS.split_name)
  model = common_flags.create_model(dataset.num_char_classes,
                                    dataset.max_sequence_length,
                                    dataset.num_of_views, dataset.null_code)
  hparams = get_training_hparams()

  # If ps_tasks is zero, the local device is used. When using multiple
  # (non-local) replicas, the ReplicaDeviceSetter distributes the variables
  # across the different devices.
  device_setter = tf.train.replica_device_setter(
      FLAGS.ps_tasks, merge_devices=True)
  with tf.device(device_setter):
    data = data_provider.get_data(
        dataset,
        FLAGS.batch_size,
        augment=hparams.use_augment_input,
        central_crop_size=common_flags.get_crop_size())
    endpoints = model.create_base(data.images, data.labels_one_hot)
    total_loss = model.create_loss(data, endpoints)
    model.create_summaries(data, endpoints, dataset.charset, is_training=True)
    init_fn = model.create_init_fn_to_restore(FLAGS.checkpoint,
                                              FLAGS.checkpoint_inception)
    if FLAGS.show_graph_stats:
      logging.info('Total number of weights in the graph: %s',
                   calculate_graph_metrics())
    train(total_loss, init_fn, hparams)


if __name__ == '__main__':
  app.run()