""" Statistics calculation utility """ import time import math import sys from onmt.utils.logging import logger class Statistics(object): """ Accumulator for loss statistics. Currently calculates: * accuracy * perplexity * elapsed time """ def __init__(self, loss=0, n_words=0, n_correct=0): self.loss = loss self.n_words = n_words self.n_correct = n_correct self.n_src_words = 0 self.start_time = time.time() @staticmethod def all_gather_stats(stat, max_size=4096): """ Gather a `Statistics` object accross multiple process/nodes Args: stat(:obj:Statistics): the statistics object to gather accross all processes/nodes max_size(int): max buffer size to use Returns: `Statistics`, the update stats object """ stats = Statistics.all_gather_stats_list([stat], max_size=max_size) return stats[0] @staticmethod def all_gather_stats_list(stat_list, max_size=4096): """ Gather a `Statistics` list accross all processes/nodes Args: stat_list(list([`Statistics`])): list of statistics objects to gather accross all processes/nodes max_size(int): max buffer size to use Returns: our_stats(list([`Statistics`])): list of updated stats """ from torch.distributed import get_rank from onmt.utils.distributed import all_gather_list # Get a list of world_size lists with len(stat_list) Statistics objects all_stats = all_gather_list(stat_list, max_size=max_size) our_rank = get_rank() our_stats = all_stats[our_rank] for other_rank, stats in enumerate(all_stats): if other_rank == our_rank: continue for i, stat in enumerate(stats): our_stats[i].update(stat, update_n_src_words=True) return our_stats def update(self, stat, update_n_src_words=False): """ Update statistics by suming values with another `Statistics` object Args: stat: another statistic object update_n_src_words(bool): whether to update (sum) `n_src_words` or not """ self.loss += stat.loss self.n_words += stat.n_words self.n_correct += stat.n_correct if update_n_src_words: self.n_src_words += stat.n_src_words def accuracy(self): """ compute accuracy """ return 100 * (self.n_correct / self.n_words) def xent(self): """ compute cross entropy """ return self.loss / self.n_words def ppl(self): """ compute perplexity """ return math.exp(min(self.loss / self.n_words, 100)) def elapsed_time(self): """ compute elapsed time """ return time.time() - self.start_time def output(self, step, num_steps, learning_rate, start): """Write out statistics to stdout. Args: step (int): current step n_batch (int): total batches start (int): start time of step. """ t = self.elapsed_time() step_fmt = "%2d" % step if num_steps > 0: step_fmt = "%s/%5d" % (step_fmt, num_steps) logger.info( ("Step %s; acc: %6.2f; ppl: %5.2f; xent: %4.2f; " + "lr: %7.5f; %3.0f/%3.0f tok/s; %6.0f sec") % (step_fmt, self.accuracy(), self.ppl(), self.xent(), learning_rate, self.n_src_words / (t + 1e-5), self.n_words / (t + 1e-5), time.time() - start)) sys.stdout.flush() def log_tensorboard(self, prefix, writer, learning_rate, patience, step): """ display statistics to tensorboard """ t = self.elapsed_time() writer.add_scalar(prefix + "/xent", self.xent(), step) writer.add_scalar(prefix + "/ppl", self.ppl(), step) writer.add_scalar(prefix + "/accuracy", self.accuracy(), step) writer.add_scalar(prefix + "/tgtper", self.n_words / t, step) writer.add_scalar(prefix + "/lr", learning_rate, step) if patience is not None: writer.add_scalar(prefix + "/patience", patience, step)