# Cyril666's picture
# First model version
# 6250360
#coding=utf-8
'''
Created on 2016-9-27
@author: dengdan
'''
import matplotlib.pyplot as plt
import numpy as np
import util
def hist(x, title = None, normed = False, show = True, save = False, save_path = None, bin_count = 100, bins = None):
    """Plot a histogram of the values in ``x`` in a new figure.

    Multi-dimensional input is flattened first, so every element contributes
    one sample.

    Args:
        x: array-like of numeric samples, any shape.
        title: figure name passed to ``plt.figure``; also used as the file
            name (``title + '.png'``) when ``save`` is True.
        normed: if True, normalize the histogram (forwarded to matplotlib as
            ``density``).
        show: call ``plt.show()`` after drawing.
        save: write the figure to ``<save_path>/<title>.png`` via save_image.
        save_path: directory to save into; required when ``save`` is True.
        bin_count: number of bin *edges* generated evenly between the minimum
            and maximum of ``x`` when ``bins`` is not given.
        bins: explicit bins (anything ``plt.hist`` accepts); overrides
            ``bin_count``.

    Raises:
        ValueError: if ``save`` is True and ``save_path`` or ``title`` is None.
    """
    x = np.asarray(x)
    if x.ndim > 1:
        x = x.ravel()
    # `is None`, not `== None`: an ndarray of edges would compare element-wise
    # and make the truth test ambiguous.
    if bins is None:
        bins = np.linspace(start=x.min(), stop=x.max(), num=bin_count, endpoint=True)
    plt.figure(num=title)
    # `normed` was removed from plt.hist in matplotlib 3.1; `density` replaces it.
    plt.hist(x, bins, density=normed)
    if save:
        if save_path is None:
            raise ValueError('save_path can not be None when save is True')
        if title is None:
            raise ValueError('title can not be None when save is True')
        path = util.io.join_path(save_path, title + '.png')
        save_image(path)
    if show:
        plt.show()
def plot_solver_data(solver_path):
    """Plot the loss/accuracy curves stored in a saved solver file.

    The file is loaded with ``util.io.load``. The result must expose the
    sequences ``training_losses``, ``training_accuracies``, ``val_losses`` and
    ``val_accuracies``. Training curves are red and validation curves green;
    losses are solid and accuracies dashed. The training-loss curve is always
    drawn; the other three are skipped when empty. Ends with ``plt.show()``.

    Args:
        solver_path: path handed to ``util.io.load``; also used as the
            figure name.
    """
    data = util.io.load(solver_path)
    training_losses = data.training_losses
    training_accuracies = data.training_accuracies
    val_losses = data.val_losses
    val_accuracies = data.val_accuracies
    plt.figure(solver_path)
    # Every curve is plotted against its own index range, so series of
    # different lengths (e.g. validation evaluated less often than training)
    # no longer trigger an x/y length mismatch.
    plt.plot(range(len(training_losses)), training_losses, 'r-', label='Training Loss')
    optional_curves = [
        (training_accuracies, 'r--', 'Training Accuracy'),
        (val_losses, 'g-', 'Validation Loss'),
        (val_accuracies, 'g--', 'Validation Accuracy'),
    ]
    for values, fmt, label in optional_curves:
        if len(values) > 0:
            plt.plot(range(len(values)), values, fmt, label=label)
    plt.legend()
    plt.show()
def rectangle(xy, width, height, color = 'red', linewidth = 1, fill = False, alpha = None, axis = None):
    """Create a matplotlib Rectangle patch, optionally attaching it to an axis.

    Args:
        xy: anchor point of the rectangle, passed straight to ``Rectangle``.
        width, height: size of the rectangle.
        color: edge/face color.
        linewidth: outline width.
        fill: whether the rectangle is filled.
        alpha: transparency, or None for matplotlib's default.
        axis: if given, the patch is added to it with ``add_patch``.

    Returns:
        The created ``matplotlib.patches.Rectangle`` (whether or not it was
        added to an axis).
    """
    from matplotlib.patches import Rectangle
    patch = Rectangle(xy=xy, width=width, height=height, alpha=alpha,
                      color=color, fill=fill, linewidth=linewidth)
    if axis is None:
        return patch
    axis.add_patch(patch)
    return patch


# Short alias kept for existing callers.
rect = rectangle
def maximize_figure():
    """Toggle full-screen mode on the current figure's window.

    This calls ``full_screen_toggle``; calling it on a window that is already
    full screen switches it back.
    """
    plt.get_current_fig_manager().full_screen_toggle()
def line(xy_start, xy_end, color = 'red', linewidth = 1, alpha = None, axis = None):
    """Create a straight Line2D from ``xy_start`` to ``xy_end``.

    The segment is sampled at 100 evenly spaced points, endpoints included.

    Args:
        xy_start, xy_end: indexable points; only elements [0] (x) and [1] (y)
            are used.
        color, linewidth, alpha: forwarded to ``Line2D``.
        axis: if given, the line is added to it with ``add_line``.

    Returns:
        The created ``matplotlib.lines.Line2D``.
    """
    from matplotlib.lines import Line2D
    sample_count = 100
    x_points = np.linspace(xy_start[0], xy_end[0], num=sample_count)
    y_points = np.linspace(xy_start[1], xy_end[1], num=sample_count)
    segment = Line2D(xdata=x_points, ydata=y_points,
                     color=color, linewidth=linewidth, alpha=alpha)
    if axis is None:
        return segment
    axis.add_line(segment)
    return segment
def imshow(title = None, img = None, gray = False):
    """Show one image (with an optional title) through show_images.

    Returns None; show_images' return value is discarded.
    """
    single_image = [img]
    single_title = [title]
    show_images(single_image, single_title, gray=gray)
def show_images(images, titles = None, shape = None, share_axis = False,
                bgr2rgb = False, maximized = False,
                show = True, gray = False, save = False, colorbar = False,
                path = None, axis_off = False, vertical = False, subtitle = None):
    """Draw several images on a grid of subplots in the current figure.

    Args:
        images: sequence of images; each one is drawn with ``Axes.imshow``.
        titles: optional sequence of per-image titles, indexed like ``images``.
        shape: ``(rows, cols)`` of the subplot grid. Defaults to one row, or
            one column when ``vertical`` is True. Images fill the grid row
            by row.
        share_axis: share the x/y axes of every subplot with the first one.
        bgr2rgb: convert each image with ``util.img.bgr2rgb`` before drawing.
        maximized: call maximize_figure (toggles full screen) before showing.
        show: call ``plt.show()`` at the end.
        gray: draw 2-D images with the ``'gray'`` colormap.
        save: save the figure to ``path`` via save_image.
        colorbar: add a colorbar next to each 2-D image.
        path: output file; required when ``save`` is True.
        axis_off: hide the axes of every subplot.
        vertical: see ``shape``.
        subtitle: optional figure-level title, set via set_subtitle.

    Returns:
        List of the created Axes, in image order.

    Raises:
        ValueError: if ``save`` is True and ``path`` is None.
    """
    if shape is None:
        shape = (len(images), 1) if vertical else (1, len(images))
    n_cols = shape[1]
    ret_axes = []
    ax0 = None
    for idx, img in enumerate(images):
        if bgr2rgb:
            img = util.img.bgr2rgb(img)
        # Integer (row, col): subplot2grid rejects the float that `/` yields
        # under Python 3.
        loc = divmod(idx, n_cols)
        if idx > 0 and share_axis:
            ax = plt.subplot2grid(shape, loc, sharex=ax0, sharey=ax0)
        else:
            ax = plt.subplot2grid(shape, loc)
        if idx == 0:
            ax0 = ax
        is_2d = np.ndim(img) == 2
        if is_2d and gray:
            img_ax = ax.imshow(img, cmap='gray')
        else:
            img_ax = ax.imshow(img)
        if is_2d and colorbar:
            plt.colorbar(img_ax, ax=ax)
        # `is not None` so an ndarray of titles is not compared element-wise.
        if titles is not None:
            ax.set_title(titles[idx])
        if axis_off:
            # Target this subplot explicitly rather than "the current axes".
            ax.axis('off')
        ret_axes.append(ax)
    if subtitle is not None:
        set_subtitle(subtitle)
    if maximized:
        maximize_figure()
    if save:
        if path is None:
            raise ValueError('path can not be None when save is True')
        save_image(path)
    if show:
        plt.show()
    return ret_axes
def save_image(path, img = None, dpi = 150):
    """Write either ``img`` or the current figure to ``path``.

    The path is resolved with ``util.io.get_absolute_path``, and
    ``util.io.make_parent_dir`` is called on it before writing.

    Args:
        path: destination file.
        img: array written with ``plt.imsave``. When None, the current figure
            is saved with ``savefig`` instead.
        dpi: resolution for the ``savefig`` path; unused when ``img`` is given.
    """
    target = util.io.get_absolute_path(path)
    util.io.make_parent_dir(target)
    if img is not None:
        plt.imsave(target, img)
        return
    plt.gcf().savefig(target, dpi=dpi)


# Alias kept for existing callers.
imwrite = save_image
def to_ROI(ax, ROI):
    """Limit ``ax`` to a rectangular region of interest.

    Args:
        ax: object with ``set_xlim``/``set_ylim`` (e.g. a matplotlib Axes).
        ROI: pair of corners ``((xmin, ymin), (xmax, ymax))``.

    The y-limits are set as ``(ymax, ymin)``, i.e. inverted, so y grows
    downward as in image coordinates.
    """
    (xmin, ymin), (xmax, ymax) = ROI
    ax.set_xlim(xmin, xmax)
    ax.set_ylim(ymax, ymin)
def set_subtitle(title, fontsize = 12):
    """Set the figure-level title (``suptitle``) of the current figure."""
    figure = plt.gcf()
    figure.suptitle(title, fontsize=fontsize)
def show(maximized = False):
    """Call ``plt.show()``, first toggling full screen when ``maximized``."""
    if not maximized:
        plt.show()
        return
    maximize_figure()
    plt.show()
def draw():
    """Redraw the current figure's canvas by calling ``canvas.draw()``."""
    canvas = plt.gcf().canvas
    canvas.draw()
def get_random_line_style():
    """Return a random matplotlib format string such as ``'r-'``.

    The color is picked from red/green/blue. Only the solid line type is
    enabled. Randomness comes from ``util.rand.randint``, called once for the
    color and once for the line type.
    """
    palette = ['r', 'g', 'b']
    # Other candidates, currently disabled: '--', '-.', ':'
    dash_styles = ['-']
    chosen_color = palette[util.rand.randint(len(palette))]
    chosen_dash = dash_styles[util.rand.randint(len(dash_styles))]
    return chosen_color + chosen_dash