Spaces:
Runtime error
Runtime error
import io | |
import time | |
import logging | |
from datetime import datetime | |
import numpy as np | |
from torch.utils.tensorboard import SummaryWriter | |
# Logger name and timestamp format shared by the helpers in this module.
LOGGER_NAME = 'root'
LOGGER_DATEFMT = '%Y-%m-%d %H:%M:%S'

# Module-wide logger: accepts INFO and above.
logger = logging.getLogger(LOGGER_NAME)
logger.setLevel(logging.INFO)

# Stream handler attached at import time so records are echoed immediately.
handler = logging.StreamHandler()
logger.addHandler(handler)
def add_logging(logs_path, prefix):
    """Attach a timestamped file handler to the module logger.

    The log file is named ``<prefix><YYYY-mm-dd_HH-MM-SS>.log`` (current
    local time) and is placed under ``logs_path``.

    Args:
        logs_path: Directory for the log file. It is combined with the file
            name using the ``/`` operator, so it must support that
            (e.g. a ``pathlib.Path``).
        prefix: String prepended to the timestamp in the file name.
    """
    timestamp = datetime.today().strftime('%Y-%m-%d_%H-%M-%S')
    log_file = logs_path / (prefix + timestamp + '.log')

    file_handler = logging.FileHandler(str(log_file))
    file_handler.setFormatter(
        logging.Formatter(
            fmt='(%(levelname)s) %(asctime)s: %(message)s',
            datefmt=LOGGER_DATEFMT,
        )
    )
    logger.addHandler(file_handler)
class TqdmToLogger(io.StringIO):
    """File-like sink that forwards the most recent written text to a logger.

    ``write`` keeps only the latest message (stripped of surrounding
    whitespace); ``flush`` emits it through ``logger.log`` when more than
    ``mininterval`` seconds have passed since the previous emission.
    Presumably meant to be passed as the output file of a tqdm progress
    bar -- confirm against callers.
    """

    # Class-level defaults; ``buf`` lets ``flush`` run before any ``write``.
    logger = None
    level = None
    buf = ''

    def __init__(self, logger, level=None, mininterval=5):
        """Store the target logger, log level and minimum emit interval.

        Args:
            logger: Object exposing ``log(level, message)``.
            level: Level used for emitted records; any falsy value falls
                back to ``logging.INFO``.
            mininterval: Minimum number of seconds between two emissions.
        """
        super().__init__()
        self.logger = logger
        self.level = level if level else logging.INFO
        self.mininterval = mininterval
        # Zero makes the very first non-empty flush pass the interval check.
        self.last_time = 0

    def write(self, buf):
        """Remember ``buf`` (whitespace-stripped), replacing older text."""
        self.buf = buf.strip('\r\n\t ')

    def flush(self):
        """Log the remembered text if it is non-empty and the interval elapsed."""
        if not self.buf:
            return
        elapsed = time.time() - self.last_time
        if elapsed > self.mininterval:
            self.logger.log(self.level, self.buf)
            self.last_time = time.time()
class SummaryWriterAvg(SummaryWriter):
    """``SummaryWriter`` that writes averaged scalars instead of every sample.

    Scalars are collected per tag in a ``ScalarAccumulator``; once an
    accumulator reports it is full, its value is written and it is reset.
    """

    def __init__(self, *args, dump_period=20, **kwargs):
        """Create the writer.

        Args:
            *args: Forwarded to ``SummaryWriter.__init__``.
            dump_period: Period handed to each per-tag ``ScalarAccumulator``.
            **kwargs: Forwarded to ``SummaryWriter.__init__``.
        """
        super().__init__(*args, **kwargs)
        self._dump_period = dump_period
        # One accumulator per tag, created on first use.
        self._avg_scalars = {}

    def add_scalar(self, tag, value, global_step=None, disable_avg=False):
        """Record ``value`` for ``tag``, averaging unless told otherwise.

        Args:
            tag: Scalar identifier.
            value: Value to record. Tuples, lists and dicts bypass averaging.
            global_step: Step passed through to the parent writer.
            disable_avg: When true, write ``value`` immediately.
        """
        write_directly = disable_avg or isinstance(value, (tuple, list, dict))
        if write_directly:
            super().add_scalar(tag, np.array(value), global_step=global_step)
            return

        accumulator = self._avg_scalars.get(tag)
        if accumulator is None:
            accumulator = ScalarAccumulator(self._dump_period)
            self._avg_scalars[tag] = accumulator

        accumulator.add(value)
        if accumulator.is_full():
            # NOTE(review): `value` is read as an attribute here, which
            # requires ScalarAccumulator.value to be a property -- confirm.
            super().add_scalar(tag, accumulator.value,
                               global_step=global_step)
            accumulator.reset()
class ScalarAccumulator(object):
    """Running sum/count of scalar samples with a fixed reporting period."""

    def __init__(self, period):
        """Start empty.

        Args:
            period: Number of samples after which ``is_full`` returns True.
        """
        self.sum = 0
        self.cnt = 0
        self.period = period

    def add(self, value):
        """Add one sample to the running sum and bump the sample count."""
        self.sum += value
        self.cnt += 1

    # Exposed as a property because the consumer in this file reads
    # ``accumulator.value`` as an attribute; as a plain method that read
    # would yield a bound method instead of the mean.
    @property
    def value(self):
        """Mean of the samples added since the last reset, or 0 if none."""
        if self.cnt > 0:
            return self.sum / self.cnt
        return 0

    def reset(self):
        """Discard all accumulated samples."""
        self.cnt = 0
        self.sum = 0

    def is_full(self):
        """Return True once at least ``period`` samples have been added."""
        return self.cnt >= self.period

    def __len__(self):
        """Return the number of samples added since the last reset."""
        return self.cnt