"""Results object manages distributed reading and writing of results to disk."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import ast
from collections import namedtuple
import os
import re

from six.moves import xrange
import tensorflow as tf


ShardStats = namedtuple(
    'ShardStats',
    ['num_local_reps_completed', 'max_local_reps', 'finished'])


def ge_non_zero(a, b):
  """Returns True if `a >= b` and `b` is strictly positive."""
  return a >= b and b > 0


def get_shard_id(file_name):
  """Extracts the integer shard id from a results file name."""
  assert file_name[-4:].lower() == '.txt'
  return int(file_name[file_name.rfind('_') + 1: -4])


class Results(object):
  """Manages reading and writing training results to disk asynchronously.

  Each worker writes to its own file, so there are no race conditions when
  writes happen. However, any worker may read any file, as is the case for
  `read_all`. Writes are expected to be atomic so that workers never read
  incomplete data, and this is likely to be the case on Unix systems. Reading
  out-of-date data is fine, as workers calling `read_all` will wait until data
  from every worker has been written before proceeding.
  """

  file_template = 'experiment_results_{0}.txt'
  search_regex = r'^experiment_results_([0-9]+)\.txt$'

  def __init__(self, log_dir, shard_id=0):
    """Construct `Results` instance.

    Args:
      log_dir: Where to write results files.
      shard_id: Unique id for this file (i.e. shard). Each worker that will
          be writing results should use a different shard id. If there are
          N shards, each shard should be numbered 0 through N-1.
    """
    # Use different files for workers so that they can write to disk async.
    assert 0 <= shard_id
    self.file_name = self.file_template.format(shard_id)
    self.log_dir = log_dir
    self.results_file = os.path.join(self.log_dir, self.file_name)

  def append(self, metrics):
    """Append results to the results list on disk."""
    with tf.gfile.FastGFile(self.results_file, 'a') as writer:
      writer.write(str(metrics) + '\n')

  def read_this_shard(self):
    """Read results only from this shard."""
    return self._read_shard(self.results_file)

  def _read_shard(self, results_file):
    """Read results only from the given shard file."""
    try:
      with tf.gfile.FastGFile(results_file, 'r') as reader:
        results = [ast.literal_eval(entry) for entry in reader]
    except tf.errors.NotFoundError:
      # No results have been written to disk yet. Return an empty list.
      return []
    return results

  def _get_max_local_reps(self, shard_results):
    """Get maximum number of repetitions the given shard needs to complete.

    The worker assigned to each shard needs to complete a certain number of
    runs before it finishes. This method returns that number so that we can
    determine which shards are still not done. We assume that each worker
    includes a 'max_local_repetitions' value in its results, which should be
    the total number of repetitions it needs to run.

    Args:
      shard_results: List of dicts mapping metric names to values, as read
          from a single shard on disk.

    Returns:
      Maximum number of repetitions the given shard needs to complete.
    """
    mlrs = [r['max_local_repetitions'] for r in shard_results]
    if not mlrs:
      return 0
    for n in mlrs[1:]:
      assert n == mlrs[0], 'Some reps have different max rep.'
    return mlrs[0]

  def read_all(self, num_shards=None):
    """Read results across all shards, i.e. get the global results list.

    Args:
      num_shards: (optional) Specifies the total number of shards. If the
          caller wants information about which shards are incomplete, provide
          this argument (so that shards which have yet to be created are still
          counted as incomplete shards). Otherwise, no information about
          incomplete shards will be returned.

    Returns:
      aggregate: Global list of results (across all shards).
      shard_stats: List of ShardStats instances, one for each shard, or None
          if `num_shards` is None.
    """
    try:
      all_children = tf.gfile.ListDirectory(self.log_dir)
    except tf.errors.NotFoundError:
      # The log directory does not exist yet, so no shard has written results.
      if num_shards is None:
        return [], None
      return [], [ShardStats(num_local_reps_completed=0, max_local_reps=0,
                             finished=False)
                  for _ in xrange(num_shards)]

    # Map shard id -> results file name for every shard found on disk.
    shard_ids = {
        get_shard_id(fname): fname
        for fname in all_children if re.search(self.search_regex, fname)}

    if num_shards is None:
      aggregate = []
      shard_stats = None
      for results_file in shard_ids.values():
        aggregate.extend(self._read_shard(
            os.path.join(self.log_dir, results_file)))
    else:
      results_per_shard = [None] * num_shards
      for shard_id in xrange(num_shards):
        if shard_id in shard_ids:
          results_file = shard_ids[shard_id]
          results_per_shard[shard_id] = self._read_shard(
              os.path.join(self.log_dir, results_file))
        else:
          results_per_shard[shard_id] = []

      # Compute shard stats.
      shard_stats = []
      for shard_results in results_per_shard:
        max_local_reps = self._get_max_local_reps(shard_results)
        shard_stats.append(ShardStats(
            num_local_reps_completed=len(shard_results),
            max_local_reps=max_local_reps,
            finished=ge_non_zero(len(shard_results), max_local_reps)))

      # Compute aggregate.
      aggregate = [
          r for shard_results in results_per_shard for r in shard_results]

    return aggregate, shard_stats
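

# A minimal usage sketch (added for illustration; not part of the original
# module). The scratch directory '/tmp/results_demo' and the 'loss' metric are
# placeholders; the only key `read_all` actually relies on is
# 'max_local_repetitions'.
if __name__ == '__main__':
  log_dir = '/tmp/results_demo'
  tf.gfile.MakeDirs(log_dir)
  results = Results(log_dir, shard_id=0)
  # Each worker appends one dict per completed repetition.
  results.append({'loss': 1.5, 'max_local_repetitions': 2})
  results.append({'loss': 1.2, 'max_local_repetitions': 2})
  # Any worker can aggregate across shards; here there is only shard 0.
  aggregate, shard_stats = results.read_all(num_shards=1)
  print('aggregate:', aggregate)
  print('shard 0 finished:', shard_stats[0].finished)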