""" Statistics calculation utility """
import time
import math
import sys

from onmt.utils.logging import logger


class Statistics(object):
    """
    Accumulator for loss statistics.
    Currently calculates:

    * accuracy
    * perplexity
    * elapsed time
    """

    def __init__(self, loss=0, n_words=0, n_correct=0):
        self.loss = loss
        self.n_words = n_words
        self.n_correct = n_correct
        self.n_src_words = 0
        self.start_time = time.time()
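
    # Illustrative construction (a documentation sketch; the values below are
    # invented, not taken from the original code): `loss` is a running sum over
    # target tokens, `n_words` the number of target tokens scored, and
    # `n_correct` the number predicted correctly. `n_src_words` starts at 0 and
    # is only filled in through `update(..., update_n_src_words=True)`, where it
    # feeds the source tok/s rate reported by `output()`.
    #
    #     batch_stats = Statistics(loss=138.6, n_words=100, n_correct=40)
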
    @staticmethod
    def all_gather_stats(stat, max_size=4096):
        """
        Gather a `Statistics` object across multiple processes/nodes

        Args:
            stat(:obj:Statistics): the statistics object to gather
                across all processes/nodes
            max_size(int): max buffer size to use

        Returns:
            `Statistics`, the updated stats object
        """
        stats = Statistics.all_gather_stats_list([stat], max_size=max_size)
        return stats[0]
    @staticmethod
    def all_gather_stats_list(stat_list, max_size=4096):
        """
        Gather a `Statistics` list across all processes/nodes

        Args:
            stat_list(list([`Statistics`])): list of statistics objects to
                gather across all processes/nodes
            max_size(int): max buffer size to use

        Returns:
            our_stats(list([`Statistics`])): list of updated stats
        """
        from torch.distributed import get_rank
        from onmt.utils.distributed import all_gather_list

        # Collect every rank's stat_list, then fold the other ranks' stats
        # into our own copy.
        all_stats = all_gather_list(stat_list, max_size=max_size)

        our_rank = get_rank()
        our_stats = all_stats[our_rank]
        for other_rank, stats in enumerate(all_stats):
            if other_rank == our_rank:
                continue
            for i, stat in enumerate(stats):
                our_stats[i].update(stat, update_n_src_words=True)
        return our_stats
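
    # Distributed usage sketch (illustrative only; the variable names are
    # invented and it assumes a process group has already been initialised via
    # `torch.distributed.init_process_group`, which happens outside this
    # module):
    #
    #     # each rank accumulates its own per-batch statistics ...
    #     local_stats = Statistics(loss=..., n_words=..., n_correct=...)
    #     # ... then merges them across ranks before reporting
    #     global_stats = Statistics.all_gather_stats(local_stats)
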
    def update(self, stat, update_n_src_words=False):
        """
        Update statistics by summing values with another `Statistics` object

        Args:
            stat: another statistic object
            update_n_src_words(bool): whether to update (sum) `n_src_words`
                or not

        """
        self.loss += stat.loss
        self.n_words += stat.n_words
        self.n_correct += stat.n_correct

        if update_n_src_words:
            self.n_src_words += stat.n_src_words
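
    # Accumulation sketch (illustrative; the variable names and loss values are
    # invented): a running `Statistics` object is typically updated once per
    # batch and reset after each report.
    #
    #     report_stats = Statistics()
    #     for batch_stats in (Statistics(69.3, 50, 20), Statistics(69.3, 50, 20)):
    #         report_stats.update(batch_stats)
    #     report_stats.accuracy()  # 40.0 over the 100 accumulated target words
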
    def accuracy(self):
        """ compute accuracy """
        return 100 * (self.n_correct / self.n_words)

    def xent(self):
        """ compute cross entropy """
        return self.loss / self.n_words

    def ppl(self):
        """ compute perplexity """
        return math.exp(min(self.loss / self.n_words, 100))

    def elapsed_time(self):
        """ compute elapsed time """
        return time.time() - self.start_time
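
    # How the metrics relate (documentation sketch; the numbers are invented):
    # xent() is the summed loss averaged per target word, and ppl() is
    # exp(xent()) with the exponent clipped at 100 to avoid overflow when a
    # model diverges.
    #
    #     stats = Statistics(loss=138.6, n_words=100, n_correct=40)
    #     stats.xent()  # 1.386
    #     stats.ppl()   # math.exp(1.386) ~= 4.0
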
    def output(self, step, num_steps, learning_rate, start):
        """Write out statistics to stdout.

        Args:
            step (int): current step
            num_steps (int): total number of steps
            learning_rate (float): current learning rate
            start (float): start time of the step, as given by `time.time()`
        """
        t = self.elapsed_time()
        step_fmt = "%2d" % step
        if num_steps > 0:
            step_fmt = "%s/%5d" % (step_fmt, num_steps)
        logger.info(
            ("Step %s; acc: %6.2f; ppl: %5.2f; xent: %4.2f; " +
             "lr: %7.5f; %3.0f/%3.0f tok/s; %6.0f sec")
            % (step_fmt,
               self.accuracy(),
               self.ppl(),
               self.xent(),
               learning_rate,
               self.n_src_words / (t + 1e-5),
               self.n_words / (t + 1e-5),
               time.time() - start))
        sys.stdout.flush()
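
    # Reporting sketch (illustrative; the variable names, step numbers, and
    # values are invented):
    #
    #     report_stats.output(step=100, num_steps=10000,
    #                         learning_rate=1.0, start=train_start_time)
    #
    # logs a line along the lines of:
    #
    #     Step 100/10000; acc: 40.00; ppl:  4.00; xent: 1.39; lr: 1.00000; ...
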
    def log_tensorboard(self, prefix, writer, learning_rate, patience, step):
        """ display statistics to tensorboard """
        t = self.elapsed_time()
        writer.add_scalar(prefix + "/xent", self.xent(), step)
        writer.add_scalar(prefix + "/ppl", self.ppl(), step)
        writer.add_scalar(prefix + "/accuracy", self.accuracy(), step)
        writer.add_scalar(prefix + "/tgtper", self.n_words / t, step)
        writer.add_scalar(prefix + "/lr", learning_rate, step)
        if patience is not None:
            writer.add_scalar(prefix + "/patience", patience, step)
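

# TensorBoard usage sketch (illustrative; `SummaryWriter` comes from
# `torch.utils.tensorboard`, which is not imported by this module, and the
# variable names, log directory, and values are invented):
#
#     writer = SummaryWriter(log_dir="runs/example")
#     valid_stats.log_tensorboard("valid", writer, learning_rate=1.0,
#                                 patience=None, step=100)

# Minimal self-check when the module is run directly (a documentation sketch,
# not part of the original module; the numbers are invented).
if __name__ == "__main__":
    demo = Statistics(loss=138.6, n_words=100, n_correct=40)
    demo.update(Statistics(loss=138.6, n_words=100, n_correct=40))
    print("acc=%.2f ppl=%.2f xent=%.2f"
          % (demo.accuracy(), demo.ppl(), demo.xent()))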
|