Spaces:
Build error
Build error
File size: 4,517 Bytes
6250360 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
# A simple torch style logger
# (C) Wei YANG 2017
from __future__ import absolute_import
# import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('pdf')
import matplotlib.pyplot as plt
import os
import sys
import numpy as np
__all__ = ['Logger', 'LoggerMonitor', 'savefig']
def savefig(fname, dpi=None):
    '''Save the current matplotlib figure to ``fname``.

    Args:
        fname: Destination handed straight to ``plt.savefig``.
        dpi: Resolution in dots per inch; 150 is used when ``None``.
    '''
    # Identity test on purpose: ``== None`` would dispatch to an overloaded
    # ``__eq__`` if the caller passed an object that defines one.
    if dpi is None:
        dpi = 150
    plt.savefig(fname, dpi=dpi)
def plot_overlap(logger, names=None):
    '''Plot the selected series of ``logger`` with ``plt.plot``.

    Args:
        logger: Object exposing ``names``, ``numbers`` (a mapping from series
            name to a sequence of values) and ``title``.
        names: Series to draw; ``logger.names`` is used when ``None``.

    Returns:
        A list with one legend label per requested series, each formatted
        as ``title(name)``.
    '''
    if names is None:
        names = logger.names
    numbers = logger.numbers
    for name in names:
        # The values may arrive as text (for instance when they were parsed
        # from a log file); coerce them so they are drawn on a numeric axis
        # instead of being treated as string categories.
        values = np.asarray(numbers[name], dtype=float)
        plt.plot(np.arange(len(values)), values)
    return [logger.title + '(' + name + ')' for name in names]
class Logger(object):
    '''Save training process to log file with simple plot function.

    The log is a tab-separated text file. The first line holds the column
    names, every following line holds one row of values formatted with six
    decimals. Every field, including the last one of a line, is followed by
    a tab character.

    Public attributes:
        file: The open file object, or ``None`` when no path was given.
        resume: The ``resume`` flag passed to the constructor.
        title: Prefix used when building legend labels.
        names: Ordered list of column names.
        numbers: Dict mapping each column name to its list of values.
    '''

    def __init__(self, fpath, title=None, resume=False):
        '''Open the log file, optionally loading its previous content.

        Args:
            fpath: Path of the log file. With ``None`` no file is opened and
                the logger only keeps the values in memory.
            title: Label prefix; the empty string is used when ``None``.
            resume: When true, the existing file is read into ``names`` and
                ``numbers`` and then reopened for appending. Otherwise the
                file is opened for writing, which truncates it.

        Raises:
            ValueError: A resumed row contains a field that is not a number.
            IndexError: A resumed row has more fields than the header.
        '''
        self.file = None
        self.resume = resume
        self.title = '' if title is None else title
        # Defined up front so the attributes exist even before set_names()
        # is called or when no file is used.
        self.names = []
        self.numbers = {}
        if fpath is None:
            return
        if resume:
            self._load(fpath)
            self.file = open(fpath, 'a')
        else:
            self.file = open(fpath, 'w')

    def _load(self, fpath):
        '''Read the header and every row of an existing log into memory.'''
        # The context manager releases the read handle even if a row fails
        # to parse.
        with open(fpath, 'r') as log:
            header = log.readline()
            self.names = header.rstrip().split('\t')
            self.numbers = {name: [] for name in self.names}
            for line in log:
                fields = line.rstrip().split('\t')
                if fields == ['']:
                    # Blank line: nothing to load.
                    continue
                for index, field in enumerate(fields):
                    # Stored as float so loaded history has the same type as
                    # the values added later through append().
                    self.numbers[self.names[index]].append(float(field))

    def _write_row(self, fields):
        '''Write one line of already formatted fields and flush it.'''
        if self.file is None:
            return
        self.file.write(''.join(field + '\t' for field in fields) + '\n')
        self.file.flush()

    def set_names(self, names):
        '''Declare the column names and write them as a header line.

        All values currently held in ``numbers`` are discarded.

        NOTE(review): this also happens on a resumed logger, where it drops
        the loaded history and appends a second header line to the file.
        The original code had a no-op ``if self.resume: pass`` here, so the
        intended behaviour on resume is unclear - confirm with the callers.

        Args:
            names: Sequence of column name strings.
        '''
        self.names = names
        self.numbers = {name: [] for name in names}
        self._write_row(names)

    def append(self, numbers):
        '''Append one row of values, one per declared column.

        Args:
            numbers: Sequence of numbers in the same order as ``names``.

        Raises:
            AssertionError: The row length differs from the column count.
        '''
        assert len(self.names) == len(numbers), 'Numbers do not match names'
        # Format the whole row first so that a value that cannot be
        # formatted leaves both the file and the in-memory lists untouched.
        fields = ["{0:.6f}".format(num) for num in numbers]
        self._write_row(fields)
        for name, num in zip(self.names, numbers):
            self.numbers[name].append(num)

    def plot(self, names=None):
        '''Placeholder: plotting is disabled, only the word 'plot' is printed.

        Args:
            names: Accepted for interface compatibility and ignored.
        '''
        # The parenthesised form is valid on both Python 2 and Python 3; the
        # bare ``print 'plot'`` statement is a syntax error on Python 3.
        print('plot')
        # Disabled implementation, kept for reference:
        # names = self.names if names is None else names
        # numbers = self.numbers
        # for name in names:
        #     x = np.arange(len(numbers[name]))
        #     plt.plot(x, np.asarray(numbers[name]))
        # plt.legend([self.title + '(' + name + ')' for name in names])
        # plt.grid(True)

    def close(self):
        '''Close the log file if one was opened.'''
        if self.file is not None:
            self.file.close()
class LoggerMonitor(object):
    '''Load and visualize multiple logs.'''

    def __init__(self, paths):
        '''Load every log listed in ``paths``.

        Args:
            paths: Dictionary of ``{name: filepath}`` pairs. Each name is
                used as the title of the logger loaded from that file.
        '''
        self.loggers = []
        for title, path in paths.items():
            logger = Logger(path, title=title, resume=True)
            # Resume mode leaves the file open for appending, but the
            # monitor only reads the values already loaded into memory, so
            # release the handle right away instead of leaking it.
            logger.close()
            self.loggers.append(logger)

    def plot(self, names=None):
        '''Draw the selected series of every loaded log on a new figure.

        Args:
            names: Series to draw for each log; passed through to
                ``plot_overlap``.
        '''
        plt.figure()
        # Only the first cell of a 1x2 grid is used; the legend below is
        # anchored just outside the right edge of these axes.
        plt.subplot(121)
        legend_text = []
        for logger in self.loggers:
            legend_text += plot_overlap(logger, names)
        plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
        plt.grid(True)
if __name__ == '__main__':
    # Demo entry point: overlays one field from several existing log files
    # and saves the resulting figure.

    # # Example
    # logger = Logger('test.txt')
    # logger.set_names(['Train loss', 'Valid loss','Test loss'])
    # length = 100
    # t = np.arange(length)
    # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
    # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
    # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
    # for i in range(0, length):
    #     logger.append([train_loss[i], valid_loss[i], test_loss[i]])
    # logger.plot()

    # Example: logger monitor
    # NOTE: the paths are machine-specific; point them at existing log files
    # before running this script.
    paths = {
        'resadvnet20':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt',
        'resadvnet32':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt',
        'resadvnet44':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt',
    }
    field = ['Valid Acc.']
    monitor = LoggerMonitor(paths)
    monitor.plot(names=field)
    savefig('test.eps')