|
|
|
import os.path as osp |
|
|
|
from annotator.uniformer.mmcv.utils import TORCH_VERSION, digit_version |
|
from ...dist_utils import master_only |
|
from ..hook import HOOKS |
|
from .base import LoggerHook |
|
|
|
|
|
@HOOKS.register_module()
class TensorboardLoggerHook(LoggerHook):
    """Logger hook that writes the runner's loggable tags to TensorBoard.

    String-valued tags are written with ``SummaryWriter.add_text`` and all
    other tags with ``SummaryWriter.add_scalar``.

    Args:
        log_dir (str, optional): Directory handed to ``SummaryWriter``. When
            left as ``None`` it is replaced by ``<runner.work_dir>/tf_logs``
            in :meth:`before_run`. Default: None.
        interval: Forwarded unchanged to ``LoggerHook.__init__``.
            Default: 10.
        ignore_last: Forwarded unchanged to ``LoggerHook.__init__``.
            Default: True.
        reset_flag: Forwarded unchanged to ``LoggerHook.__init__``.
            Default: False.
        by_epoch: Forwarded unchanged to ``LoggerHook.__init__``.
            Default: True.
    """

    def __init__(self,
                 log_dir=None,
                 interval=10,
                 ignore_last=True,
                 reset_flag=False,
                 by_epoch=True):
        base = super(TensorboardLoggerHook, self)
        base.__init__(interval, ignore_last, reset_flag, by_epoch)
        self.log_dir = log_dir

    @master_only
    def before_run(self, runner):
        """Pick a ``SummaryWriter`` implementation and open the writer.

        ``tensorboardX`` is used when ``TORCH_VERSION`` is ``'parrots'`` or
        older than 1.1; otherwise ``torch.utils.tensorboard`` is used.

        Raises:
            ImportError: If the selected ``SummaryWriter`` backend cannot be
                imported.
        """
        super(TensorboardLoggerHook, self).before_run(runner)

        # Short-circuit: digit_version is not evaluated at all when
        # TORCH_VERSION is the literal 'parrots'.
        needs_tensorboardx = TORCH_VERSION == 'parrots' or (
            digit_version(TORCH_VERSION) < digit_version('1.1'))

        if needs_tensorboardx:
            try:
                from tensorboardX import SummaryWriter
            except ImportError:
                raise ImportError('Please install tensorboardX to use '
                                  'TensorboardLoggerHook.')
        else:
            try:
                from torch.utils.tensorboard import SummaryWriter
            except ImportError:
                raise ImportError(
                    'Please run "pip install future tensorboard" to install '
                    'the dependencies to use torch.utils.tensorboard '
                    '(applicable to PyTorch 1.1 or higher)')

        # Resolve the default lazily, since it depends on the runner.
        if self.log_dir is None:
            self.log_dir = osp.join(runner.work_dir, 'tf_logs')
        self.writer = SummaryWriter(self.log_dir)

    @master_only
    def log(self, runner):
        """Write every loggable tag of ``runner`` to the summary writer.

        Each record is stamped with the value of ``self.get_iter(runner)``.
        """
        loggable = self.get_loggable_tags(runner, allow_text=True)
        for name, value in loggable.items():
            # Strings go to the text panel; everything else is a scalar.
            if isinstance(value, str):
                record = self.writer.add_text
            else:
                record = self.writer.add_scalar
            record(name, value, self.get_iter(runner))

    @master_only
    def after_run(self, runner):
        """Close the summary writer opened in :meth:`before_run`."""
        self.writer.close()
|
|