Spaces:
Build error
Build error
#coding=utf-8 | |
''' | |
Created on 2016-9-27 | |
@author: dengdan | |
''' | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import util | |
def hist(x, title = None, normed = False, show = True, save = False, save_path = None, bin_count = 100, bins = None):
    """
    Plot a histogram of the values in x on a (possibly new) figure.

    x: array-like of numbers. Inputs with more than one dimension are
       flattened so that all values end up in a single histogram.
    title: figure name (plt.figure(num=title)) and, when save is True, the
       file name `<save_path>/<title>.png`.
    normed: normalize the histogram to a probability density. This is
       forwarded to matplotlib as `density`, because the old `normed` keyword
       was removed from plt.hist.
    show: call plt.show() after plotting.
    save: save the figure with save_image.
    save_path: directory the figure is saved into; required when save is True.
    bin_count: number of equal-width bins spanning the data range; only used
       when bins is None.
    bins: anything plt.hist accepts as `bins` (e.g. explicit bin edges);
       overrides bin_count.

    Raises ValueError when save is True but save_path or title is None.
    """
    x = np.asarray(x)
    if x.ndim > 1:
        # plt.hist treats each column of a 2-D array as a separate dataset
        x = util.np.flatten(x)
    if bins is None:
        # `bins == None` would compare element-wise when bins is an array.
        # An integer makes matplotlib/numpy build exactly bin_count
        # equal-width bins over [min(x), max(x)], and it also copes with
        # constant data (min == max).
        bins = bin_count
    plt.figure(num = title)
    plt.hist(x, bins, density = normed)
    if save:
        if save_path is None:
            raise ValueError('save_path can not be None when save is True')
        if title is None:
            raise ValueError('title can not be None when save is True')
        path = util.io.join_path(save_path, title + '.png')
        save_image(path)
    if show:
        plt.show()
def plot_solver_data(solver_path):
    """
    Plot the loss/accuracy curves stored in a solver file and show the figure.

    solver_path: file loaded with util.io.load. The loaded object must expose
        the attributes training_losses, training_accuracies, val_losses and
        val_accuracies; any of them may be empty, and empty curves are
        skipped. solver_path is also used as the figure name.

    Each curve is plotted against its own index range 0..len(curve)-1. Curves
    of different lengths (e.g. validation recorded less often than training)
    therefore never cause an x/y length mismatch.
    Colors and line styles:
        training = red, validation = green; loss = solid, accuracy = dashed.
    """
    data = util.io.load(solver_path)
    curves = [
        (data.training_losses, 'r-', 'Training Loss'),
        (data.training_accuracies, 'r--', 'Training Accuracy'),
        (data.val_losses, 'g-', 'Validation Loss'),
        (data.val_accuracies, 'g--', 'Validation Accuracy'),
    ]
    plt.figure(solver_path)
    for values, style, label in curves:
        # len() rather than truthiness: the curves may be numpy arrays
        if len(values) > 0:
            plt.plot(range(len(values)), values, style, label = label)
    plt.legend()
    plt.show()
def rectangle(xy, width, height, color = 'red', linewidth = 1, fill = False, alpha = None, axis = None):
    """
    Build a matplotlib Rectangle patch and optionally attach it to an axis.

    xy: anchor point, passed to Rectangle as `xy`.
    width, height: size of the rectangle.
    color, linewidth, fill, alpha: forwarded unchanged to
        matplotlib.patches.Rectangle.
    axis: if not None, the patch is added to it with axis.add_patch.

    Returns the created Rectangle.
    """
    from matplotlib.patches import Rectangle
    patch = Rectangle(xy = xy, width = width, height = height,
                      alpha = alpha, color = color, fill = fill,
                      linewidth = linewidth)
    if axis is not None:
        axis.add_patch(patch)
    return patch

# Short alias kept for callers that use `rect(...)`.
rect = rectangle
def maximize_figure():
    """
    Switch the current figure's window in or out of full-screen mode.

    This calls the figure manager's full_screen_toggle, so calling it on a
    window that is already full screen restores its normal size.
    """
    plt.get_current_fig_manager().full_screen_toggle()
def line(xy_start, xy_end, color = 'red', linewidth = 1, alpha = None, axis = None):
    """
    Build a straight matplotlib Line2D from xy_start to xy_end.

    xy_start, xy_end: points indexable as [0] (x) and [1] (y).
    The segment is sampled at 100 evenly spaced points, endpoints included.
    color, linewidth, alpha: forwarded unchanged to Line2D.
    axis: if not None, the line is added to it with axis.add_line.

    Returns the created Line2D.
    """
    from matplotlib.lines import Line2D
    samples = 100
    xs = np.linspace(xy_start[0], xy_end[0], num = samples)
    ys = np.linspace(xy_start[1], xy_end[1], num = samples)
    segment = Line2D(xdata = xs, ydata = ys, color = color,
                     linewidth = linewidth, alpha = alpha)
    if axis is not None:
        axis.add_line(segment)
    return segment
def imshow(title = None, img = None, gray = False):
    """
    Display one image by delegating to show_images.

    title: title of the single subplot (sent as a one-element titles list).
    img: the image to display.
    gray: draw 2-D images with the gray colormap.

    Returns None (the axes created by show_images are discarded).
    """
    show_images(images = [img], titles = [title], gray = gray)
def show_images(images, titles = None, shape = None, share_axis = False,
                bgr2rgb = False, maximized = False,
                show = True, gray = False, save = False, colorbar = False,
                path = None, axis_off = False, vertical = False, subtitle = None):
    """
    Draw several images on a grid of subplots in the current figure.

    images: sequence of images accepted by Axes.imshow.
    titles: optional sequence of per-image titles, indexed like images.
    shape: (rows, cols) of the subplot grid. Defaults to one row, or one
        column when vertical is True. Images fill the grid row by row.
    share_axis: share the x/y axes of every subplot with the first one.
    bgr2rgb: convert each image with util.img.bgr2rgb before drawing.
    maximized: toggle full screen via maximize_figure.
    show: call plt.show() at the end.
    gray: draw 2-D images with the gray colormap.
    save: save the figure to path with save_image.
    colorbar: add a colorbar next to every 2-D image.
    path: output file; required when save is True.
    axis_off: hide the axes of every subplot.
    vertical: stack images vertically when shape is not given.
    subtitle: optional figure-level title, set with set_subtitle.

    Returns the list of created Axes, in image order.
    Raises ValueError if save is True and path is None.
    """
    if shape is None:
        shape = (len(images), 1) if vertical else (1, len(images))
    n_cols = shape[1]
    ret_axes = []
    ax0 = None
    for idx, img in enumerate(images):
        if bgr2rgb:
            img = util.img.bgr2rgb(img)
        # Integer division: subplot2grid needs an integer (row, col); on
        # Python 3, `/` would produce a float row index.
        loc = (idx // n_cols, idx % n_cols)
        # For the first image ax0 is still None, i.e. no sharing.
        share = ax0 if share_axis else None
        ax = plt.subplot2grid(shape, loc, sharex = share, sharey = share)
        if ax0 is None:
            ax0 = ax
        is_2d = len(np.shape(img)) == 2
        if is_2d and gray:
            img_ax = ax.imshow(img, cmap = 'gray')
        else:
            img_ax = ax.imshow(img)
        if is_2d and colorbar:
            plt.colorbar(img_ax, ax = ax)
        if titles is not None:
            ax.set_title(titles[idx])
        if axis_off:
            ax.axis('off')
        ret_axes.append(ax)
    if subtitle is not None:
        set_subtitle(subtitle)
    if maximized:
        maximize_figure()
    if save:
        if path is None:
            raise ValueError('path can not be None when save is True')
        save_image(path)
    if show:
        plt.show()
    return ret_axes
def save_image(path, img = None, dpi = 150):
    """
    Save an image array, or the current figure, to path.

    path: output file. It is passed through util.io.get_absolute_path, and
        util.io.make_parent_dir is called on the result before writing
        (presumably this creates the missing parent directory).
    img: if given, this array is written with plt.imsave. Otherwise the
        current figure is written with savefig.
    dpi: resolution used for figure saving; ignored when img is given.
    """
    target = util.io.get_absolute_path(path)
    util.io.make_parent_dir(target)
    if img is not None:
        plt.imsave(target, img)
        return
    plt.gcf().savefig(target, dpi = dpi)

# Alias kept for callers that use `imwrite(...)`.
imwrite = save_image
def to_ROI(ax, ROI):
    """
    Restrict the view of ax to a rectangular region of interest.

    ROI: ((xmin, ymin), (xmax, ymax)).
    The x limits become (xmin, xmax). The y limits are set as (ymax, ymin),
    so y grows downward; presumably this matches image row coordinates as
    drawn by imshow.
    """
    (x_lo, y_lo), (x_hi, y_hi) = ROI
    ax.set_xlim(x_lo, x_hi)
    ax.set_ylim(y_hi, y_lo)
def set_subtitle(title, fontsize = 12):
    """Set the figure-level title (suptitle) of the current figure."""
    figure = plt.gcf()
    figure.suptitle(title, fontsize = fontsize)
def show(maximized = False):
    """
    Display the open figures with plt.show().

    maximized: if True, first toggle the current figure to full screen via
        maximize_figure.
    """
    if not maximized:
        plt.show()
        return
    maximize_figure()
    plt.show()
def draw():
    """Redraw the canvas of the current figure."""
    canvas = plt.gcf().canvas
    canvas.draw()
def get_random_line_style():
    """
    Return a random matplotlib format string such as 'g-'.

    The color is drawn from red/green/blue and the line type from the enabled
    line types (currently only solid). Both draws use util.rand.randint, with
    the color drawn first.
    """
    palette = ['r', 'g', 'b']
    # '--', '-.' and ':' are intentionally disabled
    styles = ['-']
    chosen_color = palette[util.rand.randint(len(palette))]
    chosen_style = styles[util.rand.randint(len(styles))]
    return chosen_color + chosen_style