File size: 4,370 Bytes
158b61b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
""" Statistics calculation utility """
import time
import math
import sys
from onmt.utils.logging import logger
class Statistics(object):
    """
    Accumulator for loss statistics.

    Currently calculates:

    * accuracy
    * perplexity
    * elapsed time

    The object holds running totals (``loss``, ``n_words``, ``n_correct``,
    ``n_src_words``) plus the wall-clock time at which it was created; the
    derived metrics are computed from those totals on demand.
    """

    def __init__(self, loss=0, n_words=0, n_correct=0):
        """
        Args:
            loss: initial summed loss
            n_words: initial word count; the denominator of `accuracy`,
                `xent` and `ppl`
            n_correct: initial count of correct predictions
        """
        self.loss = loss
        self.n_words = n_words
        self.n_correct = n_correct
        # Always starts at zero; `update` only adds to it when asked to.
        self.n_src_words = 0
        # Reference point for `elapsed_time`.
        self.start_time = time.time()

    @staticmethod
    def all_gather_stats(stat, max_size=4096):
        """
        Gather a `Statistics` object across multiple processes/nodes

        Args:
            stat(:obj:Statistics): the statistics object to gather
                across all processes/nodes
            max_size(int): max buffer size to use

        Returns:
            `Statistics`, the updated stats object
        """
        # Single-object convenience wrapper around the list variant.
        gathered = Statistics.all_gather_stats_list([stat], max_size=max_size)
        return gathered[0]

    @staticmethod
    def all_gather_stats_list(stat_list, max_size=4096):
        """
        Gather a `Statistics` list across all processes/nodes

        Args:
            stat_list(list([`Statistics`])): list of statistics objects to
                gather across all processes/nodes
            max_size(int): max buffer size to use

        Returns:
            our_stats(list([`Statistics`])): list of updated stats; this is
                the entry of the gathered result at the local rank, with
                every other rank's entries summed into it (including
                `n_src_words`)
        """
        # Imported at call time rather than at module import.
        from torch.distributed import get_rank
        from onmt.utils.distributed import all_gather_list

        # Get a list of world_size lists with len(stat_list) Statistics objects
        all_stats = all_gather_list(stat_list, max_size=max_size)

        our_rank = get_rank()
        our_stats = all_stats[our_rank]
        for other_rank, stats in enumerate(all_stats):
            # Our own entry is the accumulator; do not add it to itself.
            if other_rank == our_rank:
                continue
            for i, stat in enumerate(stats):
                our_stats[i].update(stat, update_n_src_words=True)
        return our_stats

    def update(self, stat, update_n_src_words=False):
        """
        Update statistics by summing values with another `Statistics` object

        Args:
            stat: another statistic object
            update_n_src_words(bool): whether to update (sum) `n_src_words`
                or not
        """
        self.loss += stat.loss
        self.n_words += stat.n_words
        self.n_correct += stat.n_correct

        if update_n_src_words:
            self.n_src_words += stat.n_src_words

    def accuracy(self):
        """ compute accuracy, as a percentage of `n_words`

        Divides by `n_words`, so a zero word count raises
        `ZeroDivisionError` with plain numeric counters.
        """
        return 100 * (self.n_correct / self.n_words)

    def xent(self):
        """ compute cross entropy (summed loss divided by `n_words`)

        Divides by `n_words`, so a zero word count raises
        `ZeroDivisionError` with plain numeric counters.
        """
        return self.loss / self.n_words

    def ppl(self):
        """ compute perplexity

        This is ``exp(xent())``; the exponent is capped at 100 so that a
        very large loss cannot overflow `math.exp`.
        """
        return math.exp(min(self.xent(), 100))

    def elapsed_time(self):
        """ compute elapsed time (seconds since this object was created) """
        return time.time() - self.start_time

    def output(self, step, num_steps, learning_rate, start):
        """Write out statistics through the module logger.

        Args:
            step (int): current step
            num_steps (int): total number of steps; when positive the step
                is displayed as ``step/num_steps``
            learning_rate (float): current learning rate
            start (float): reference timestamp on the `time.time()` clock;
                the trailing "sec" field is the time elapsed since it
        """
        t = self.elapsed_time()
        step_fmt = "%2d" % step
        if num_steps > 0:
            step_fmt = "%s/%5d" % (step_fmt, num_steps)
        logger.info(
            ("Step %s; acc: %6.2f; ppl: %5.2f; xent: %4.2f; " +
             "lr: %7.5f; %3.0f/%3.0f tok/s; %6.0f sec")
            % (step_fmt,
               self.accuracy(),
               self.ppl(),
               self.xent(),
               learning_rate,
               # The small epsilon keeps the rates finite when t is 0.
               self.n_src_words / (t + 1e-5),
               self.n_words / (t + 1e-5),
               time.time() - start))
        sys.stdout.flush()

    def log_tensorboard(self, prefix, writer, learning_rate, patience, step):
        """ display statistics to tensorboard

        Args:
            prefix (str): tag prefix; scalars are written as ``prefix/name``
            writer: object exposing ``add_scalar(tag, value, step)``
            learning_rate: current learning rate
            patience: logged as ``prefix/patience`` unless it is None
            step: step value passed through to every ``add_scalar`` call
        """
        t = self.elapsed_time()
        writer.add_scalar(prefix + "/xent", self.xent(), step)
        writer.add_scalar(prefix + "/ppl", self.ppl(), step)
        writer.add_scalar(prefix + "/accuracy", self.accuracy(), step)
        # Same epsilon guard as `output`: a zero elapsed time must not raise.
        writer.add_scalar(prefix + "/tgtper", self.n_words / (t + 1e-5), step)
        writer.add_scalar(prefix + "/lr", learning_rate, step)
        if patience is not None:
            writer.add_scalar(prefix + "/patience", patience, step)
|