tgritsaev's picture
Upload 198 files
affcd23 verified
raw
history blame
3.12 kB
import importlib
from datetime import datetime
class TensorboardWriter:
    """Thin wrapper around a TensorBoard ``SummaryWriter``.

    The backend is taken from ``torch.utils.tensorboard`` if importable,
    otherwise from ``tensorboardX``. The logging methods named in
    ``tb_writer_ftns`` are exposed through ``__getattr__``: each call is
    forwarded to the backend with the current ``step`` appended, and (except
    for names in ``tag_mode_exceptions``) with the tag rewritten to
    ``"<tag>/<mode>"``. Without a backend writer those calls are no-ops.
    """

    def __init__(self, log_dir, logger, enabled):
        """
        Args:
            log_dir: directory handed to ``SummaryWriter`` (converted to ``str``).
            logger: object with a ``warning(msg)`` method, used when no
                backend can be imported.
            enabled: when falsy, no writer is created and every logging call
                becomes a no-op.
        """
        self.writer = None
        # Name of the backend module actually in use ("" when none).
        self.selected_module = ""

        if enabled:
            log_dir = str(log_dir)

            # Retrieve visualization writer: the first backend that imports wins.
            for module in ["torch.utils.tensorboard", "tensorboardX"]:
                try:
                    self.writer = importlib.import_module(module).SummaryWriter(log_dir)
                except ImportError:
                    continue
                # Record the backend that succeeded (previously this was only
                # reached for a backend that had *failed* to import).
                self.selected_module = module
                break
            else:
                message = (
                    "Warning: visualization (Tensorboard) is configured to use, but currently not "
                    "installed on this machine. Please install TensorboardX with "
                    "'pip install tensorboardx', upgrade PyTorch to version >= 1.1 to use "
                    "'torch.utils.tensorboard' or turn off the option in the 'config.json' file."
                )
                logger.warning(message)

        self.step = 0
        self.mode = ""

        # Backend methods exposed through __getattr__.
        self.tb_writer_ftns = {
            "add_scalar",
            "add_scalars",
            "add_image",
            "add_images",
            "add_audio",
            "add_text",
            "add_histogram",
            "add_pr_curve",
            "add_embedding",
        }
        # Methods whose tag is passed through without the "/<mode>" suffix.
        self.tag_mode_exceptions = {"add_histogram", "add_embedding"}
        self.timer = datetime.now()

    def set_step(self, step, mode="train"):
        """Set the step and mode used by subsequent logging calls.

        For a non-zero ``step``, also logs ``steps_per_sec`` as the inverse of
        the wall-clock seconds elapsed since the timer was last reset. Step 0
        only resets the timer.
        """
        self.mode = mode
        self.step = step
        if step != 0:
            elapsed = (datetime.now() - self.timer).total_seconds()
            # Skip a zero interval (coarse clock / back-to-back calls, which
            # would divide by zero) and a negative one (clock moved backwards).
            if elapsed > 0:
                self.add_scalar("steps_per_sec", 1 / elapsed)
        self.timer = datetime.now()

    def __getattr__(self, name):
        """Resolve attributes that normal lookup did not find.

        If ``name`` is in ``tb_writer_ftns``, return a wrapper
        ``wrapper(tag, data, *args, **kwargs)`` that calls the backend method of
        the same name with the current step (and mode-suffixed tag). The
        wrapper does nothing when there is no writer or the writer lacks that
        method.

        Raises:
            AttributeError: for any other name.
        """
        # Read instance state via __dict__: this method can run on an instance
        # whose __init__ never ran (copy/pickle build one with __new__), and
        # plain ``self.attr`` access would re-enter __getattr__ endlessly.
        state = self.__dict__
        if name in state.get("tb_writer_ftns", ()):
            add_data = getattr(state.get("writer"), name, None)

            def wrapper(tag, data, *args, **kwargs):
                if add_data is not None:
                    # add mode (train/valid) tag
                    if name not in self.tag_mode_exceptions:
                        tag = "{}/{}".format(tag, self.mode)
                    add_data(tag, data, self.step, *args, **kwargs)

            return wrapper

        # __getattr__ is only invoked after regular lookup failed, so methods
        # defined on this class (set_step, ...) never reach this point.
        raise AttributeError(
            "type object '{}' has no attribute '{}'".format(
                state.get("selected_module", ""), name
            )
        )