# sakharamg's picture
# Uploading all files
# 158b61b
""" Report manager utility """
import time
from datetime import datetime
import onmt
from onmt.utils.logging import logger
def build_report_manager(opt, gpu_rank):
    """Create the `ReportMgr` used to report training progress.

    A TensorBoard ``SummaryWriter`` is attached only when ``opt.tensorboard``
    is truthy and ``gpu_rank <= 0``; otherwise the manager gets no writer.

    Args:
        opt: options object; reads ``tensorboard``, ``tensorboard_log_dir``
            and ``report_every``. ``tensorboard_log_dir_dated`` is set on it
            as a side effect when it is not already present.
        gpu_rank(int): rank of the calling process.

    Returns:
        ReportMgr: a manager that has not been started (``start_time=-1``).
    """
    writer = None
    if opt.tensorboard and gpu_rank <= 0:
        # Imported lazily so the dependency is only needed when TensorBoard
        # logging is actually requested.
        from torch.utils.tensorboard import SummaryWriter

        # Compute the dated directory only once per `opt`, so repeated calls
        # with the same options reuse the same log directory.
        if not hasattr(opt, 'tensorboard_log_dir_dated'):
            base_dir = opt.tensorboard_log_dir
            stamp = datetime.now().strftime("/%b-%d_%H-%M-%S")
            opt.tensorboard_log_dir_dated = base_dir + stamp

        # NOTE(review): per the SummaryWriter docs, `comment` only affects
        # the auto-generated default log_dir, so it presumably has no effect
        # when an explicit directory is passed -- confirm.
        writer = SummaryWriter(opt.tensorboard_log_dir_dated, comment="Unmt")

    return ReportMgr(opt.report_every, start_time=-1,
                     tensorboard_writer=writer)
class ReportMgrBase(object):
    """
    Report Manager Base class

    Owns the reporting schedule (every ``report_every`` steps) and the start
    time, and delegates the actual output to hooks.

    Inherited classes should override:
        * `_report_training`
        * `_report_step`
    """

    def __init__(self, report_every, start_time=-1.):
        """
        Args:
            report_every(int): Report training status whenever the step
                count is a multiple of this value.
            start_time(float): manually set report start time. Negative
                values mean that you will need to set it later or use
                `start()` before calling `report_training`.
        """
        self.report_every = report_every
        self.start_time = start_time

    def start(self):
        """Set the report start time to the current `time.time()` value."""
        self.start_time = time.time()

    def log(self, *args, **kwargs):
        """Forward the arguments to `logger.info`."""
        logger.info(*args, **kwargs)

    def report_training(self, step, num_steps, learning_rate, patience,
                        report_stats, multigpu=False):
        """
        This is the user-defined batch-level training progress
        report function.

        Args:
            step(int): current step count.
            num_steps(int): total number of batches.
            learning_rate(float): current learning rate.
            patience(int): current patience, forwarded to the report hook.
            report_stats(Statistics): old Statistics instance.
            multigpu(bool): when True, stats are first combined with
                `Statistics.all_gather_stats` before being reported.

        Returns:
            report_stats(Statistics): a fresh `Statistics` instance if a
            report was emitted at this step, otherwise the `report_stats`
            argument unchanged.

        Raises:
            ValueError: if the manager has not been started
                (``start_time`` is still negative).
        """
        if self.start_time < 0:
            # Single-line message: the previous triple-quoted literal leaked
            # a raw newline into the text and was missing a closing paren.
            raise ValueError(
                "ReportMgr needs to be started "
                "(set 'start_time' or use 'start()')")

        # Not a reporting step: keep accumulating into the same instance.
        if step % self.report_every != 0:
            return report_stats

        if multigpu:
            report_stats = \
                onmt.utils.Statistics.all_gather_stats(report_stats)
        self._report_training(
            step, num_steps, learning_rate, patience, report_stats)
        # The hook's return value is intentionally ignored; the caller always
        # restarts accumulation from an empty Statistics after a report.
        return onmt.utils.Statistics()

    def _report_training(self, *args, **kwargs):
        """ To be overridden """
        raise NotImplementedError()

    def report_step(self, lr, patience, step, train_stats=None,
                    valid_stats=None):
        """
        Report stats of a step

        Args:
            lr(float): current learning rate
            patience(int): current patience
            step(int): current step
            train_stats(Statistics): training stats
            valid_stats(Statistics): validation stats
        """
        self._report_step(
            lr, patience, step,
            train_stats=train_stats,
            valid_stats=valid_stats)

    def _report_step(self, *args, **kwargs):
        """ To be overridden """
        raise NotImplementedError()
class ReportMgr(ReportMgrBase):
    """Report manager that logs statistics and optionally feeds TensorBoard."""

    def __init__(self, report_every, start_time=-1., tensorboard_writer=None):
        """
        A report manager that reports statistics through the stats/logger
        output as well as (optionally) TensorBoard

        Args:
            report_every(int): Report status every this many steps
            start_time(float): report start time, passed to the base class
            tensorboard_writer(:obj:`tensorboard.SummaryWriter`):
                The TensorBoard Summary writer to use or None
        """
        super(ReportMgr, self).__init__(report_every, start_time)
        self.tensorboard_writer = tensorboard_writer

    def maybe_log_tensorboard(self, stats, prefix, learning_rate,
                              patience, step):
        """Send `stats` to TensorBoard; no-op when no writer is configured."""
        writer = self.tensorboard_writer
        if writer is None:
            return
        stats.log_tensorboard(prefix, writer, learning_rate, patience, step)

    def _report_training(self, step, num_steps, learning_rate, patience,
                         report_stats):
        """
        See base class method `ReportMgrBase.report_training`.
        """
        report_stats.output(
            step, num_steps, learning_rate, self.start_time)
        self.maybe_log_tensorboard(
            report_stats, "progress", learning_rate, patience, step)
        # Hand back an empty accumulator for the next reporting window.
        return onmt.utils.Statistics()

    def _report_step(self, lr, patience, step,
                     train_stats=None,
                     valid_stats=None):
        """
        See base class method `ReportMgrBase.report_step`.
        """
        # (stats, perplexity format, accuracy format, TensorBoard prefix);
        # training is always reported before validation.
        sections = (
            (train_stats,
             'Train perplexity: %g', 'Train accuracy: %g', "train"),
            (valid_stats,
             'Validation perplexity: %g', 'Validation accuracy: %g',
             "valid"),
        )
        for stats, ppl_fmt, acc_fmt, prefix in sections:
            if stats is None:
                continue
            self.log(ppl_fmt % stats.ppl())
            self.log(acc_fmt % stats.accuracy())
            self.maybe_log_tensorboard(stats, prefix, lr, patience, step)