|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Evaluates text classification model.""" |
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import math |
|
import time |
|
|
|
|
|
|
|
import tensorflow as tf |
|
|
|
import graphs |
|
|
|
flags = tf.app.flags
FLAGS = flags.FLAGS

# NOTE(review): run_eval() also reads FLAGS.batch_size, which is not defined
# here -- presumably it is defined by the imported `graphs` module; confirm.

# Execution target and output location.
flags.DEFINE_string(
    'master', '',
    'BNS name prefix of the Tensorflow eval master, or "local".')
flags.DEFINE_string(
    'eval_dir', '/tmp/text_eval',
    'Directory where to write event logs.')
flags.DEFINE_string(
    'eval_data', 'test',
    'Specify which dataset is used. ("train", "valid", "test") ')

# Which checkpoints to evaluate, how much data to use, and how often to run.
flags.DEFINE_string(
    'checkpoint_dir', '/tmp/text_train',
    'Directory where to read model checkpoints.')
flags.DEFINE_integer(
    'eval_interval_secs', 60,
    'How often to run the eval.')
flags.DEFINE_integer(
    'num_examples', 32,
    'Number of examples to run.')
flags.DEFINE_bool(
    'run_once', False,
    'Whether to run eval only once.')
|
|
|
|
|
def restore_from_checkpoint(sess, saver):
  """Loads the latest checkpoint found in FLAGS.checkpoint_dir into `sess`.

  Args:
    sess: Session whose variables are restored.
    saver: Saver that performs the restore.

  Returns:
    bool: True if a checkpoint path was found and restored, False otherwise
    (in which case an informational message is logged).
  """
  state = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
  checkpoint_path = state.model_checkpoint_path if state else None

  if checkpoint_path:
    saver.restore(sess, checkpoint_path)
    return True

  tf.logging.info('No checkpoint found at %s', FLAGS.checkpoint_dir)
  return False
|
|
|
|
|
def run_eval(eval_ops, summary_writer, saver):
  """Runs one evaluation pass over the latest checkpoint.

  Restores the newest checkpoint, runs
  ceil(FLAGS.num_examples / FLAGS.batch_size) metric-update steps, logs
  intermediate metric values every 50 steps, and finally logs the metric
  values and writes them as a summary at the current global step.

  Args:
    eval_ops: dict<metric name, tuple(value, update_op)>
    summary_writer: Summary writer that receives the final metric values.
    saver: Saver used to restore the checkpoint.

  Returns:
    None. If no checkpoint is found the pass is skipped; otherwise results are
    reported only through logging and `summary_writer`.

  Raises:
    ValueError: if `eval_ops` is empty.
  """
  # Fail fast, before a session is created and a checkpoint restored.
  if not eval_ops:
    raise ValueError('eval_ops must contain at least one metric.')

  sv = tf.train.Supervisor(
      logdir=FLAGS.eval_dir, saver=None, summary_op=None, summary_writer=None)
  with sv.managed_session(
      master=FLAGS.master, start_standard_services=False) as sess:
    if not restore_from_checkpoint(sess, saver):
      return
    sv.start_queue_runners(sess)

    value_ops_dict = {name: value for name, (value, _) in eval_ops.items()}
    update_ops = [update_op for _, update_op in eval_ops.values()]

    num_batches = int(math.ceil(FLAGS.num_examples / FLAGS.batch_size))
    tf.logging.info('Running %d batches for evaluation.', num_batches)
    for batch in range(1, num_batches + 1):
      if batch % 10 == 0:
        tf.logging.info('Running batch %d/%d...', batch, num_batches)
      sess.run(update_ops)
      # Log after the update so intermediate values include this batch.
      if batch % 50 == 0:
        _log_values(sess, value_ops_dict)

    _log_values(sess, value_ops_dict, summary_writer=summary_writer)
|
|
|
|
|
def _log_values(sess, value_ops, summary_writer=None):
  """Evaluates, logs, and optionally writes summaries of eval metrics.

  Args:
    sess: Session in which the metric value tensors are evaluated.
    value_ops: dict<metric name, value tensor>.
    summary_writer: Optional summary writer. When given, a Summary holding one
      scalar per metric is written at the current global step.

  Returns:
    dict<metric name, evaluated value>.
  """
  # Keep the names in a fixed order so they line up with the fetched values.
  metric_names = list(value_ops)
  values = sess.run([value_ops[name] for name in metric_names])

  tf.logging.info('Eval metric values:')
  summary = tf.summary.Summary()
  for name, val in zip(metric_names, values):
    summary.value.add(tag=name, simple_value=float(val))
    tf.logging.info('%s = %.3f', name, val)

  if summary_writer is not None:
    global_step_val = sess.run(tf.train.get_global_step())
    tf.logging.info('Finished eval for step %s', global_step_val)
    summary_writer.add_summary(summary, global_step_val)

  return dict(zip(metric_names, values))
|
|
|
|
|
def main(_):
  """Builds the eval graph and evaluates checkpoints until told to stop.

  Runs one evaluation pass every FLAGS.eval_interval_secs seconds, or exactly
  one pass when FLAGS.run_once is set.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.gfile.MakeDirs(FLAGS.eval_dir)
  tf.logging.info('Building eval graph...')
  eval_ops, moving_averaged_variables = graphs.get_model().eval_graph(
      FLAGS.eval_data)

  # Restore only the variables returned by eval_graph (presumably the
  # moving-average shadow variables, per their name -- confirm in graphs).
  saver = tf.train.Saver(moving_averaged_variables)
  summary_writer = tf.summary.FileWriter(
      FLAGS.eval_dir, graph=tf.get_default_graph())

  try:
    while True:
      run_eval(eval_ops, summary_writer, saver)
      if FLAGS.run_once:
        break
      time.sleep(FLAGS.eval_interval_secs)
  finally:
    # FileWriter buffers events; close() flushes them so the last eval
    # summary is not lost (notably with --run_once).
    summary_writer.close()
|
|
|
|
|
if __name__ == '__main__':
  # tf.app.run parses command-line flags, then invokes main() above.
  tf.app.run(main=main)
|
|