|
|
|
from enum import Enum |
|
from onmt.utils.logging import logger |
|
|
|
|
|
class PatienceEnum(Enum):
    """States an early-stopping tracker can be in after a validation."""

    # The latest validation improved every tracked metric.
    IMPROVING = 0
    # Patience is being consumed (metrics got worse, or results were mixed).
    DECREASING = 1
    # Patience is exhausted; training should end.
    STOPPED = 2
|
|
|
|
|
class Scorer(object):
    """Base class for a validation metric tracked by early stopping.

    Attributes:
        best_score: best metric value recorded so far (initially the value
            passed to the constructor).
        name (str): short label for the metric, used in log messages.

    Subclasses provide ``is_improving``, ``is_decreasing`` and ``_caller``
    (the latter extracts the metric value from a validation statistics
    object).
    """

    def __init__(self, best_score, name):
        self.name = name
        self.best_score = best_score

    def is_improving(self, stats):
        """Return whether ``stats`` beats ``best_score`` (abstract)."""
        raise NotImplementedError

    def is_decreasing(self, stats):
        """Return whether ``stats`` is worse than ``best_score`` (abstract)."""
        raise NotImplementedError

    def update(self, stats):
        """Unconditionally store the metric value of ``stats`` as best."""
        latest = self._caller(stats)
        self.best_score = latest

    def __call__(self, stats, **kwargs):
        """Return the metric value of ``stats``; extra kwargs are ignored."""
        return self._caller(stats)

    def _caller(self, stats):
        """Extract the metric value from ``stats`` (abstract)."""
        raise NotImplementedError
|
|
|
|
|
class PPLScorer(Scorer):
    """Scorer tracking validation perplexity; lower values are better.

    ``best_score`` starts at +inf, so the first perplexity below infinity
    counts as an improvement.
    """

    def __init__(self):
        super(PPLScorer, self).__init__(float("inf"), "ppl")

    def is_improving(self, stats):
        """Return True if ``stats.ppl()`` is strictly below the best."""
        current_ppl = stats.ppl()
        return current_ppl < self.best_score

    def is_decreasing(self, stats):
        """Return True if ``stats.ppl()`` is strictly above the best."""
        current_ppl = stats.ppl()
        return current_ppl > self.best_score

    def _caller(self, stats):
        """Return the perplexity reported by ``stats``."""
        return stats.ppl()
|
|
|
|
|
class AccuracyScorer(Scorer):
    """Scorer tracking validation accuracy; higher values are better.

    ``best_score`` starts at -inf, so the first accuracy above negative
    infinity counts as an improvement.
    """

    def __init__(self):
        super(AccuracyScorer, self).__init__(float("-inf"), "acc")

    def is_improving(self, stats):
        """Return True if ``stats.accuracy()`` is strictly above the best."""
        current_acc = stats.accuracy()
        return current_acc > self.best_score

    def is_decreasing(self, stats):
        """Return True if ``stats.accuracy()`` is strictly below the best."""
        current_acc = stats.accuracy()
        return current_acc < self.best_score

    def _caller(self, stats):
        """Return the accuracy reported by ``stats``."""
        return stats.accuracy()
|
|
|
|
|
# Scorers used when no explicit early-stopping criteria are configured.
# NOTE: these are single shared instances and they are stateful (``update``
# overwrites ``best_score``); code needing independent tracking should build
# new instances of these types rather than reuse the objects themselves.
DEFAULT_SCORERS = [
    scorer_cls() for scorer_cls in (PPLScorer, AccuracyScorer)
]

# Maps an ``early_stopping_criteria`` option value to its scorer class.
SCORER_BUILDER = dict(
    ppl=PPLScorer,
    accuracy=AccuracyScorer,
)
|
|
|
|
|
def scorers_from_opts(opt):
    """Build the scorers used for early stopping from training options.

    Args:
        opt: options object; only ``opt.early_stopping_criteria`` is read.
            It is either ``None`` or an iterable of ``SCORER_BUILDER`` keys.

    Returns:
        list: newly constructed scorer instances. With criteria given, one
        scorer per distinct criterion, in the order each criterion first
        appears. Without criteria, new instances of the ``DEFAULT_SCORERS``
        types.

    Raises:
        AssertionError: (when assertions are enabled) if a criterion is not
            a key of ``SCORER_BUILDER``.
    """
    if opt.early_stopping_criteria is None:
        # Scorers are stateful (``best_score``), so hand out fresh objects
        # rather than the shared module-level DEFAULT_SCORERS instances.
        return [type(scorer)() for scorer in DEFAULT_SCORERS]

    scorers = []
    seen = set()
    # Deduplicate while keeping the user's order; iterating a ``set`` of
    # strings would make the order depend on the interpreter's hash seed.
    for criterion in opt.early_stopping_criteria:
        if criterion in seen:
            continue
        seen.add(criterion)
        assert criterion in SCORER_BUILDER, \
            "Criterion {} not found".format(criterion)
        scorers.append(SCORER_BUILDER[criterion]())
    return scorers
|
|
|
|
|
class EarlyStopping(object):
    """Callable tracker that decides when training should stop early.

    Each call compares one validation result against the best scores held
    by the configured scorers and moves ``status`` between the
    ``PatienceEnum`` states ``IMPROVING``, ``DECREASING`` and ``STOPPED``.
    """

    def __init__(self, tolerance, scorers=None):
        """
        Callable class to keep track of early stopping.

        Args:
            tolerance(int): patience budget. Two separate budgets of this
                size are kept: one for validations where every metric got
                worse and one for stalled (mixed) validations. Both are
                reset whenever every metric improves; training stops when
                either runs out.
            scorers(list): scorers used to validate performance on dev.
                When omitted (or ``None``), new instances of the
                ``DEFAULT_SCORERS`` types are created for this tracker.
                An empty list makes every validation count as improving,
                because ``all()`` over nothing is ``True``.
        """
        if scorers is None:
            # Scorers are stateful (``update`` overwrites ``best_score``), so
            # each tracker gets its own instances instead of sharing the
            # module-level DEFAULT_SCORERS objects with every other tracker.
            scorers = [type(scorer)() for scorer in DEFAULT_SCORERS]

        self.tolerance = tolerance
        # Remaining budget for stalled (mixed) validations.
        self.stalled_tolerance = self.tolerance
        # Remaining budget for validations where every metric got worse.
        self.current_tolerance = self.tolerance
        self.early_stopping_scorers = scorers
        self.status = PatienceEnum.IMPROVING
        # Step of the most recent validation on which every metric improved.
        self.current_step_best = 0

    def __call__(self, valid_stats, step):
        """
        Update the internal state of early stopping mechanism, whether to
        continue training or stop the train procedure.

        Checks whether the scores from all pre-chosen scorers improved. If
        every metric improves, then the status is switched to improving and
        the tolerances are reset. If every metric deteriorates, then the
        status is switched to decreasing and the tolerance is decreased; if
        the tolerance reaches 0, then the status is changed to stopped.
        Finally, if some improve and others do not, it is considered
        stalled; after tolerance number of stalled validations, the status
        is switched to stopped.

        :param valid_stats: Statistics of dev set
        :param step: training step at which ``valid_stats`` were computed;
            recorded as the best step when every metric improves.
        """
        if self.status == PatienceEnum.STOPPED:
            # Once stopped, further validations are ignored.
            return

        scorers = self.early_stopping_scorers
        if all(scorer.is_improving(valid_stats) for scorer in scorers):
            self._update_increasing(valid_stats, step)
        elif all(scorer.is_decreasing(valid_stats) for scorer in scorers):
            self._update_decreasing()
        else:
            self._update_stalled()

    def _update_stalled(self):
        """Consume one unit of the stalled budget; stop when it runs out."""
        self.stalled_tolerance -= 1

        logger.info(
            "Stalled patience: {}/{}".format(self.stalled_tolerance,
                                             self.tolerance))

        # ``<= 0`` mirrors the stop condition in
        # ``_decreasing_or_stopped_status_update`` so a non-positive initial
        # tolerance still logs why training stopped.
        if self.stalled_tolerance <= 0:
            logger.info(
                "Training finished after stalled validations. Early Stop!"
            )
            self._log_best_step()

        self._decreasing_or_stopped_status_update(self.stalled_tolerance)

    def _update_increasing(self, valid_stats, step):
        """Record a full improvement: new best scores and reset budgets."""
        self.current_step_best = step
        for scorer in self.early_stopping_scorers:
            logger.info(
                "Model is improving {}: {:g} --> {:g}.".format(
                    scorer.name, scorer.best_score, scorer(valid_stats))
            )
            scorer.update(valid_stats)

        # An improvement on every metric refills both patience budgets.
        self.current_tolerance = self.tolerance
        self.stalled_tolerance = self.tolerance

        self.status = PatienceEnum.IMPROVING

    def _update_decreasing(self):
        """Consume one unit of the decreasing budget; stop when it runs out."""
        self.current_tolerance -= 1

        logger.info(
            "Decreasing patience: {}/{}".format(self.current_tolerance,
                                                self.tolerance)
        )

        # ``<= 0`` mirrors the stop condition in
        # ``_decreasing_or_stopped_status_update``.
        if self.current_tolerance <= 0:
            logger.info("Training finished after not improving. Early Stop!")
            self._log_best_step()

        self._decreasing_or_stopped_status_update(self.current_tolerance)

    def _log_best_step(self):
        """Log the step at which every metric last improved."""
        logger.info("Best model found at step {}".format(
            self.current_step_best))

    def _decreasing_or_stopped_status_update(self, tolerance):
        """Set DECREASING while budget remains, otherwise STOPPED."""
        if tolerance > 0:
            self.status = PatienceEnum.DECREASING
        else:
            self.status = PatienceEnum.STOPPED

    def is_improving(self):
        """Return True if the last validation improved every metric."""
        return self.status == PatienceEnum.IMPROVING

    def has_stopped(self):
        """Return True once the patience budget has been exhausted."""
        return self.status == PatienceEnum.STOPPED
|
|