|
|
|
from annotator.uniformer.mmcv.utils import Registry, is_method_overridden
|
|
|
|
# Module-level registry named 'hook'. Presumably Hook subclasses are
# registered here so they can be looked up/built by name — confirm against
# the Registry implementation and its callers.
HOOKS = Registry('hook')
|
|
|
|
|
|
class Hook:
    """Base class for runner hooks.

    Every callback defined here does nothing; subclasses override only the
    ones they care about. The train/val-specific epoch and iteration
    callbacks forward to the generic ``before_epoch`` / ``after_epoch`` /
    ``before_iter`` / ``after_iter`` methods, so overriding one generic
    method covers both its training and validation variants.
    """

    # Every stage name, in the order ``get_triggered_stages`` reports them.
    stages = (
        'before_run',
        'before_train_epoch',
        'before_train_iter',
        'after_train_iter',
        'after_train_epoch',
        'before_val_epoch',
        'before_val_iter',
        'after_val_iter',
        'after_val_epoch',
        'after_run',
    )

    # ------------------------------------------------------------------
    # Run-level callbacks (no-ops by default)
    # ------------------------------------------------------------------

    def before_run(self, runner):
        """Hook for the ``before_run`` stage; does nothing by default."""

    def after_run(self, runner):
        """Hook for the ``after_run`` stage; does nothing by default."""

    # ------------------------------------------------------------------
    # Generic callbacks shared by the train and val variants below
    # ------------------------------------------------------------------

    def before_epoch(self, runner):
        """Shared hook run by both ``before_train_epoch`` and
        ``before_val_epoch``; does nothing by default."""

    def after_epoch(self, runner):
        """Shared hook run by both ``after_train_epoch`` and
        ``after_val_epoch``; does nothing by default."""

    def before_iter(self, runner):
        """Shared hook run by both ``before_train_iter`` and
        ``before_val_iter``; does nothing by default."""

    def after_iter(self, runner):
        """Shared hook run by both ``after_train_iter`` and
        ``after_val_iter``; does nothing by default."""

    # ------------------------------------------------------------------
    # Stage-specific callbacks: each forwards to its generic counterpart.
    # The forwarded call's result is intentionally discarded (returns None).
    # ------------------------------------------------------------------

    def before_train_epoch(self, runner):
        """Forward to ``before_epoch``."""
        self.before_epoch(runner)

    def before_val_epoch(self, runner):
        """Forward to ``before_epoch``."""
        self.before_epoch(runner)

    def after_train_epoch(self, runner):
        """Forward to ``after_epoch``."""
        self.after_epoch(runner)

    def after_val_epoch(self, runner):
        """Forward to ``after_epoch``."""
        self.after_epoch(runner)

    def before_train_iter(self, runner):
        """Forward to ``before_iter``."""
        self.before_iter(runner)

    def before_val_iter(self, runner):
        """Forward to ``before_iter``."""
        self.before_iter(runner)

    def after_train_iter(self, runner):
        """Forward to ``after_iter``."""
        self.after_iter(runner)

    def after_val_iter(self, runner):
        """Forward to ``after_iter``."""
        self.after_iter(runner)

    # ------------------------------------------------------------------
    # Progress predicates on the runner's counters
    # ------------------------------------------------------------------

    def every_n_epochs(self, runner, n):
        """Return True when ``runner.epoch + 1`` is divisible by ``n``.

        A non-positive ``n`` disables the check and always yields False.
        """
        if n <= 0:
            return False
        return (runner.epoch + 1) % n == 0

    def every_n_inner_iters(self, runner, n):
        """Return True when ``runner.inner_iter + 1`` is divisible by ``n``.

        A non-positive ``n`` disables the check and always yields False.
        """
        if n <= 0:
            return False
        return (runner.inner_iter + 1) % n == 0

    def every_n_iters(self, runner, n):
        """Return True when ``runner.iter + 1`` is divisible by ``n``.

        A non-positive ``n`` disables the check and always yields False.
        """
        if n <= 0:
            return False
        return (runner.iter + 1) % n == 0

    def end_of_epoch(self, runner):
        """Return True when ``runner.inner_iter + 1`` equals
        ``len(runner.data_loader)``, i.e. the current inner iteration is
        the last one of the loader."""
        iters_per_epoch = len(runner.data_loader)
        return runner.inner_iter + 1 == iters_per_epoch

    def is_last_epoch(self, runner):
        """Return True when ``runner.epoch + 1`` equals
        ``runner._max_epochs``."""
        return runner._max_epochs == runner.epoch + 1

    def is_last_iter(self, runner):
        """Return True when ``runner.iter + 1`` equals
        ``runner._max_iters``."""
        return runner._max_iters == runner.iter + 1

    # ------------------------------------------------------------------
    # Introspection
    # ------------------------------------------------------------------

    def get_triggered_stages(self):
        """Return the stages this hook instance actually customizes.

        A stage counts as triggered when either its own method is
        overridden relative to ``Hook``, or the generic method it forwards
        to (e.g. ``before_epoch`` for ``before_train_epoch``) is overridden.
        The result follows the order of ``Hook.stages``.
        """
        # Stages whose own callback is overridden.
        triggered = {
            stage_name
            for stage_name in Hook.stages
            if is_method_overridden(stage_name, Hook, self)
        }

        # A generic callback fans out to both its train and val stages.
        generic_to_stages = {
            'before_epoch': ['before_train_epoch', 'before_val_epoch'],
            'after_epoch': ['after_train_epoch', 'after_val_epoch'],
            'before_iter': ['before_train_iter', 'before_val_iter'],
            'after_iter': ['after_train_iter', 'after_val_iter'],
        }
        for generic_name, fanned_out in generic_to_stages.items():
            if is_method_overridden(generic_name, Hook, self):
                triggered.update(fanned_out)

        # Re-impose the canonical stage order on the unordered set.
        ordered = []
        for stage_name in Hook.stages:
            if stage_name in triggered:
                ordered.append(stage_name)
        return ordered
|
|
|