Spaces:
Build error
Build error
# A simple torch style logger | |
# (C) Wei YANG 2017 | |
from __future__ import absolute_import | |
# import matplotlib.pyplot as plt | |
import matplotlib | |
matplotlib.use('pdf') | |
import matplotlib.pyplot as plt | |
import os | |
import sys | |
import numpy as np | |
__all__ = ['Logger', 'LoggerMonitor', 'savefig'] | |
def savefig(fname, dpi=None, **kwargs):
    '''Save the current matplotlib figure to ``fname``.

    Args:
        fname: output path, passed straight to ``plt.savefig``.
        dpi: resolution in dots per inch; defaults to 150 when omitted.
        **kwargs: extra keyword arguments forwarded to ``plt.savefig``
            (e.g. ``bbox_inches='tight'`` so a legend placed outside the
            axes is not clipped).
    '''
    dpi = 150 if dpi is None else dpi
    plt.savefig(fname, dpi=dpi, **kwargs)
def plot_overlap(logger, names=None):
    '''Plot the ``names`` series of ``logger`` on the current axes.

    Args:
        logger: object exposing ``names``, ``numbers`` (a dict mapping each
            name to its list of values) and ``title``, e.g. a ``Logger``.
        names: series to plot; defaults to all of ``logger.names``.

    Returns:
        Legend labels, one ``title(name)`` string per plotted series, in the
        same order as the curves were drawn.
    '''
    names = logger.names if names is None else names
    for name in names:
        # Logs re-read from disk may hold their values as strings; convert so
        # the curve is drawn numerically instead of as categorical labels.
        values = np.asarray(logger.numbers[name], dtype=float)
        plt.plot(np.arange(len(values)), values)
    return [logger.title + '(' + name + ')' for name in names]
class Logger(object):
    '''Save training process to log file with simple plot function.

    The log is a tab-separated text file: the first line holds the column
    names and every following line holds one row of values, each written
    with six decimal places and followed by a tab.

    Args:
        fpath: path of the log file. If None, no file is opened (and
            ``set_names``/``append`` cannot write).
        title: prefix used in plot legend labels; defaults to ''.
        resume: if True, read names and values from the existing file at
            ``fpath`` and keep appending to it; otherwise ``fpath`` is opened
            for writing (truncated).
    '''

    def __init__(self, fpath, title=None, resume=False):
        self.file = None
        self.resume = resume
        self.title = '' if title is None else title
        if fpath is not None:
            if resume:
                self._load(fpath)
                self.file = open(fpath, 'a')
            else:
                self.file = open(fpath, 'w')

    def _load(self, fpath):
        '''Read the header and all rows of an existing log into memory.'''
        with open(fpath, 'r') as f:
            self.names = f.readline().rstrip().split('\t')
            self.numbers = {name: [] for name in self.names}
            for line in f:
                values = line.rstrip().split('\t')
                if values == ['']:
                    # Blank line (e.g. a trailing newline): no data.
                    continue
                for i, value in enumerate(values):
                    # Store floats so resumed values match the ones passed
                    # to append() and can be plotted numerically.
                    self.numbers[self.names[i]].append(float(value))

    def set_names(self, names):
        '''Write ``names`` as the header line and reset the stored values.

        This writes a header even when resuming, so callers resuming an
        existing log should not call it (the names were read from the file).
        '''
        self.names = names
        # initialize numbers as empty list
        self.numbers = {name: [] for name in names}
        self.file.write(''.join(name + '\t' for name in names) + '\n')
        self.file.flush()

    def append(self, numbers):
        '''Append one row of values, one per name, to the log and to memory.'''
        assert len(self.names) == len(numbers), 'Numbers do not match names'
        # Format the whole row first so a bad value cannot leave a
        # half-written line in the file.
        line = ''.join('{0:.6f}\t'.format(num) for num in numbers) + '\n'
        self.file.write(line)
        self.file.flush()
        for name, num in zip(self.names, numbers):
            self.numbers[name].append(num)

    def plot(self, names=None):
        '''Plot the ``names`` series (default: all) on the current axes.'''
        plt.legend(plot_overlap(self, names))
        plt.grid(True)

    def close(self):
        '''Close the underlying log file, if one was opened.'''
        if self.file is not None:
            self.file.close()
class LoggerMonitor(object):
    '''Load and visualize multiple logs.'''

    def __init__(self, paths):
        '''Load every log listed in ``paths``.

        Args:
            paths: a dictionary of ``{name: filepath}`` pairs; each name is
                used as the legend prefix for that log's curves.
        '''
        self.loggers = []
        for title, path in paths.items():
            logger = Logger(path, title=title, resume=True)
            # The monitor only reads the logs; release the append-mode
            # handle that Logger keeps open after resuming.
            logger.close()
            self.loggers.append(logger)

    def plot(self, names=None):
        '''Overlay the ``names`` series (default: all) of every log in a new figure.'''
        plt.figure()
        plt.subplot(121)
        legend_text = []
        for logger in self.loggers:
            legend_text += plot_overlap(logger, names)
        # The axes occupy only the left half (subplot 121); anchor the legend
        # just to their right.
        plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
        plt.grid(True)
if __name__ == '__main__':
    # # Example: single logger
    # logger = Logger('test.txt')
    # logger.set_names(['Train loss', 'Valid loss', 'Test loss'])
    # length = 100
    # t = np.arange(length)
    # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
    # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
    # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
    # for i in range(0, length):
    #     logger.append([train_loss[i], valid_loss[i], test_loss[i]])
    # logger.plot()

    # Example: logger monitor.
    # Logs to compare may be given on the command line as title=path pairs;
    # without arguments the default paths below are used.
    paths = {
        'resadvnet20': '/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt',
        'resadvnet32': '/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt',
        'resadvnet44': '/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt',
    }
    if len(sys.argv) > 1:
        # Split only on the first '=' so paths may themselves contain '='.
        paths = dict(arg.split('=', 1) for arg in sys.argv[1:])
    field = ['Valid Acc.']
    monitor = LoggerMonitor(paths)
    monitor.plot(names=field)
    savefig('test.eps')