|
|
|
|
|
|
|
""" |
|
Assessor analyzes trial's intermediate results (e.g., periodically evaluated accuracy on test dataset) |
|
to tell whether this trial can be early stopped or not. |
|
|
|
See :class:`Assessor`' specification and ``docs/en_US/assessors.rst`` for details. |
|
""" |
|
|
|
from enum import Enum |
|
import logging |
|
|
|
from .recoverable import Recoverable |
|
|
|
__all__ = ['AssessResult', 'Assessor'] |
|
|
|
_logger = logging.getLogger(__name__) |
|
|
|
|
|
class AssessResult(Enum): |
|
""" |
|
Enum class for :meth:`Assessor.assess_trial` return value. |
|
""" |
|
|
|
Good = True |
|
"""The trial works well.""" |
|
|
|
Bad = False |
|
"""The trial works poorly and should be early stopped.""" |
|
|
|
|
|
class Assessor(Recoverable): |
|
""" |
|
Assessor analyzes trial's intermediate results (e.g., periodically evaluated accuracy on test dataset) |
|
to tell whether this trial can be early stopped or not. |
|
|
|
This is the abstract base class for all assessors. |
|
Early stopping algorithms should inherit this class and override :meth:`assess_trial` method, |
|
which receives intermediate results from trials and give an assessing result. |
|
|
|
If :meth:`assess_trial` returns :obj:`AssessResult.Bad` for a trial, |
|
it hints NNI framework that the trial is likely to result in a poor final accuracy, |
|
and therefore should be killed to save resource. |
|
|
|
If an assessor want's to be notified when a trial ends, it can also override :meth:`trial_end`. |
|
|
|
To write a new assessor, you can reference :class:`~nni.medianstop_assessor.MedianstopAssessor`'s code as an example. |
|
|
|
See Also |
|
-------- |
|
Builtin assessors: |
|
:class:`~nni.algorithms.hpo.medianstop_assessor.MedianstopAssessor` |
|
:class:`~nni.algorithms.hpo.curvefitting_assessor.CurvefittingAssessor` |
|
""" |
|
|
|
def assess_trial(self, trial_job_id, trial_history): |
|
""" |
|
Abstract method for determining whether a trial should be killed. Must override. |
|
|
|
The NNI framework has little guarantee on ``trial_history``. |
|
This method is not guaranteed to be invoked for each time ``trial_history`` get updated. |
|
It is also possible that a trial's history keeps updating after receiving a bad result. |
|
And if the trial failed and retried, ``trial_history`` may be inconsistent with its previous value. |
|
|
|
The only guarantee is that ``trial_history`` is always growing. |
|
It will not be empty and will always be longer than previous value. |
|
|
|
This is an example of how :meth:`assess_trial` get invoked sequentially: |
|
|
|
:: |
|
|
|
trial_job_id | trial_history | return value |
|
------------ | --------------- | ------------ |
|
Trial_A | [1.0, 2.0] | Good |
|
Trial_B | [1.5, 1.3] | Bad |
|
Trial_B | [1.5, 1.3, 1.9] | Good |
|
Trial_A | [0.9, 1.8, 2.3] | Good |
|
|
|
Parameters |
|
---------- |
|
trial_job_id : str |
|
Unique identifier of the trial. |
|
trial_history : list |
|
Intermediate results of this trial. The element type is decided by trial code. |
|
|
|
Returns |
|
------- |
|
AssessResult |
|
:obj:`AssessResult.Good` or :obj:`AssessResult.Bad`. |
|
""" |
|
raise NotImplementedError('Assessor: assess_trial not implemented') |
|
|
|
def trial_end(self, trial_job_id, success): |
|
""" |
|
Abstract method invoked when a trial is completed or terminated. Do nothing by default. |
|
|
|
Parameters |
|
---------- |
|
trial_job_id : str |
|
Unique identifier of the trial. |
|
success : bool |
|
True if the trial successfully completed; False if failed or terminated. |
|
""" |
|
|
|
def load_checkpoint(self): |
|
""" |
|
Internal API under revising, not recommended for end users. |
|
""" |
|
checkpoin_path = self.get_checkpoint_path() |
|
_logger.info('Load checkpoint ignored by assessor, checkpoint path: %s', checkpoin_path) |
|
|
|
def save_checkpoint(self): |
|
""" |
|
Internal API under revising, not recommended for end users. |
|
""" |
|
checkpoin_path = self.get_checkpoint_path() |
|
_logger.info('Save checkpoint ignored by assessor, checkpoint path: %s', checkpoin_path) |
|
|
|
def _on_exit(self): |
|
pass |
|
|
|
def _on_error(self): |
|
pass |
|
|