Spaces:
Running
on
A10G
Running
on
A10G
# Copyright (c) OpenMMLab. All rights reserved. | |
import os.path as osp | |
from annotator.mmpkg.mmcv.utils import TORCH_VERSION, digit_version | |
from ...dist_utils import master_only | |
from ..hook import HOOKS | |
from .base import LoggerHook | |
class TensorboardLoggerHook(LoggerHook):
    """Logger hook that forwards loggable tags to a TensorBoard writer.

    The writer class is chosen in ``before_run``: ``tensorboardX`` is used
    when ``TORCH_VERSION`` is ``'parrots'`` or older than 1.1, otherwise
    ``torch.utils.tensorboard`` is used.

    Args:
        log_dir (str, optional): Directory the summary writer is opened on.
            When ``None``, ``<runner.work_dir>/tf_logs`` is used.
            Default: None.
        interval (int): Forwarded unchanged to ``LoggerHook.__init__``.
            Default: 10.
        ignore_last (bool): Forwarded unchanged to ``LoggerHook.__init__``.
            Default: True.
        reset_flag (bool): Forwarded unchanged to ``LoggerHook.__init__``.
            Default: False.
        by_epoch (bool): Forwarded unchanged to ``LoggerHook.__init__``.
            Default: True.
    """

    def __init__(self,
                 log_dir=None,
                 interval=10,
                 ignore_last=True,
                 reset_flag=False,
                 by_epoch=True):
        super(TensorboardLoggerHook, self).__init__(interval, ignore_last,
                                                    reset_flag, by_epoch)
        self.log_dir = log_dir
        # Created lazily in ``before_run``; kept as ``None`` until then so
        # that ``after_run`` can tell whether there is anything to close.
        self.writer = None

    def before_run(self, runner):
        """Import a ``SummaryWriter`` implementation and open the writer.

        Args:
            runner: Object providing ``work_dir``; it is also passed on to
                ``LoggerHook.before_run``.

        Raises:
            ImportError: If the TensorBoard backend required for the
                detected ``TORCH_VERSION`` is not installed.
        """
        super(TensorboardLoggerHook, self).before_run(runner)
        # The ``'parrots'`` comparison comes first so that ``digit_version``
        # is never called on that value (``or`` short-circuits).
        if (TORCH_VERSION == 'parrots'
                or digit_version(TORCH_VERSION) < digit_version('1.1')):
            try:
                from tensorboardX import SummaryWriter
            except ImportError:
                raise ImportError('Please install tensorboardX to use '
                                  'TensorboardLoggerHook.')
        else:
            try:
                from torch.utils.tensorboard import SummaryWriter
            except ImportError:
                raise ImportError(
                    'Please run "pip install future tensorboard" to install '
                    'the dependencies to use torch.utils.tensorboard '
                    '(applicable to PyTorch 1.1 or higher)')

        if self.log_dir is None:
            self.log_dir = osp.join(runner.work_dir, 'tf_logs')
        self.writer = SummaryWriter(self.log_dir)

    def log(self, runner):
        """Write every loggable tag of ``runner`` to the summary writer.

        String values are written with ``add_text``; all other values are
        written with ``add_scalar``.

        Args:
            runner: Object handed to ``get_loggable_tags`` and ``get_iter``.
        """
        tags = self.get_loggable_tags(runner, allow_text=True)
        # All tags of one call share the same step, so resolve it once.
        # NOTE(review): assumes ``get_iter`` is a side-effect-free getter.
        step = self.get_iter(runner)
        for tag, val in tags.items():
            if isinstance(val, str):
                self.writer.add_text(tag, val, step)
            else:
                self.writer.add_scalar(tag, val, step)

    def after_run(self, runner):
        """Close the summary writer if one was opened.

        Args:
            runner: Unused; present to match the hook interface.
        """
        # ``before_run`` may never have executed (e.g. the run aborted
        # early), in which case there is no writer to close.
        if self.writer is not None:
            self.writer.close()