|
import os
|
|
import numpy as np
|
|
|
|
from PIL import Image
|
|
from utils import paramUtil
|
|
import math
|
|
import time
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
|
|
|
def mkdir(path):
    """Create directory ``path`` (including missing parents) if nothing exists there.

    If ``path`` already exists -- as a directory or as any other file-system
    entry -- this is a no-op.
    """
    if not os.path.exists(path):
        # exist_ok closes the race where another process (e.g. a parallel
        # worker) creates the directory between the check above and this
        # call, which would otherwise raise FileExistsError.
        os.makedirs(path, exist_ok=True)
|
|
|
|
# Fixed palette of 18 RGB triples (0-255 per channel). Read in order, the
# entries sweep red -> yellow -> green -> cyan -> blue -> magenta and back
# toward red. Each entry is a plain (mutable) list.
COLORS = [
    [255, 0, 0], [255, 85, 0], [255, 170, 0],
    [255, 255, 0], [170, 255, 0], [85, 255, 0],
    [0, 255, 0], [0, 255, 85], [0, 255, 170],
    [0, 255, 255], [0, 170, 255], [0, 85, 255],
    [0, 0, 255], [85, 0, 255], [170, 0, 255],
    [255, 0, 255], [255, 0, 170], [255, 0, 85],
]

# Sentinel integer; NOTE(review): its consumers are not visible here --
# presumably it marks absent entries, confirm against callers.
MISSING_VALUE = -1
|
|
|
|
def save_image(image_numpy, image_path):
    """Convert an array image to a PIL image and write it to ``image_path``.

    No explicit format is passed to ``Image.save``, so PIL derives the format
    from the file name.
    """
    Image.fromarray(image_numpy).save(image_path)
|
|
|
|
|
|
def save_logfile(log_loss, save_path):
    """Write ``log_loss`` to ``save_path`` as text, one line per entry.

    Each line is the entry's key followed by every value of its sequence,
    formatted with three decimals and separated by single spaces. An existing
    file is overwritten.
    """
    with open(save_path, 'wt') as handle:
        for name, series in log_loss.items():
            fields = [name] + ['%.3f' % value for value in series]
            handle.write(' '.join(fields) + '\n')
|
|
|
|
|
|
def print_current_loss(start_time, niter_state, total_niters, losses, epoch=None, sub_epoch=None,
                       inner_iter=None, tf_ratio=None, sl_steps=None):
    """Print a one-line training progress report to stdout.

    Args:
        start_time: ``time.time()`` timestamp taken when training started.
        niter_state: number of iterations completed so far.
        total_niters: total number of planned iterations.
        losses: mapping of loss name -> numeric value; each value is printed
            with four decimals.
        epoch: current epoch. When given together with ``inner_iter``, an
            ``ep/it`` prefix is printed before the report.
        sub_epoch: accepted for call-site compatibility; not used.
        inner_iter: iteration index within the current epoch.
        tf_ratio: accepted for call-site compatibility; not used.
        sl_steps: accepted for call-site compatibility; not used.
    """

    def as_minutes(s):
        # Split a duration in seconds into whole minutes and leftover seconds.
        m = math.floor(s / 60)
        s -= m * 60
        return '%dm %ds' % (m, s)

    def time_since(since, percent):
        # Elapsed time plus a linear extrapolation of the remaining time.
        elapsed = time.time() - since
        if percent <= 0:
            # Nothing completed yet: the remaining time cannot be estimated,
            # and dividing by ``percent`` would raise ZeroDivisionError.
            return '%s (- ?)' % as_minutes(elapsed)
        remaining = elapsed / percent - elapsed
        return '%s (- %s)' % (as_minutes(elapsed), as_minutes(remaining))

    # Fraction completed; guard total_niters == 0 to avoid ZeroDivisionError.
    percent = niter_state / total_niters if total_niters else 0.0

    # '%4d' % None raises TypeError, so the prefix needs both values.
    if epoch is not None and inner_iter is not None:
        print('ep/it:%2d-%4d niter:%6d' % (epoch, inner_iter, niter_state), end=" ")

    message = ' %s completed:%3d%%)' % (time_since(start_time, percent), percent * 100)
    for k, v in losses.items():
        message += ' %s: %.4f ' % (k, v)
    print(message)
|
|
|
|
def print_current_loss_decomp(start_time, niter_state, total_niters, losses, epoch=None, inner_iter=None):
    """Print a one-line progress report (epoch / iteration / losses) to stdout.

    Args:
        start_time: ``time.time()`` timestamp taken when training started.
        niter_state: number of iterations completed so far.
        total_niters: total number of planned iterations.
        losses: mapping of loss name -> numeric value; each value is printed
            with four decimals.
        epoch: current epoch. It is printed, together with ``inner_iter``, as
            a prefix when both are given.
        inner_iter: iteration index within the current epoch.
    """

    def as_minutes(s):
        # Split a duration in seconds into whole minutes and leftover seconds.
        m = math.floor(s / 60)
        s -= m * 60
        return '%dm %ds' % (m, s)

    def time_since(since, percent):
        # Elapsed time plus a linear extrapolation of the remaining time.
        elapsed = time.time() - since
        if percent <= 0:
            # Nothing completed yet: the remaining time cannot be estimated,
            # and dividing by ``percent`` would raise ZeroDivisionError.
            return '%s (- ?)' % as_minutes(elapsed)
        remaining = elapsed / percent - elapsed
        return '%s (- %s)' % (as_minutes(elapsed), as_minutes(remaining))

    # Fraction completed; guard total_niters == 0 to avoid ZeroDivisionError.
    percent = niter_state / total_niters if total_niters else 0.0

    # Both parameters default to None and '%03d' % None raises TypeError,
    # so the prefix is printed only when both are supplied.
    if epoch is not None and inner_iter is not None:
        print('epoch: %03d inner_iter: %5d' % (epoch, inner_iter), end=" ")

    message = '%s niter: %07d completed: %3d%%)' % (time_since(start_time, percent), niter_state, percent * 100)
    for k, v in losses.items():
        message += ' %s: %.4f ' % (k, v)
    print(message)
|
|
|
|
|
|
def compose_gif_img_list(img_list, fp_out, duration):
    """Save a sequence of frames as one animated GIF.

    Args:
        img_list: iterable of frames; each one is converted with ``np.array``
            and then ``Image.fromarray``.
        fp_out: destination passed to PIL's ``save`` as ``fp``.
        duration: per-frame display duration forwarded to PIL's GIF writer.

    Raises:
        ValueError: if ``img_list`` yields no frames.
    """
    frames = [Image.fromarray(np.array(image)) for image in img_list]
    if not frames:
        # Fail with a clear message instead of an opaque unpacking error.
        raise ValueError('compose_gif_img_list: img_list must contain at least one frame')
    first, *rest = frames
    # save_all + append_images writes every frame; loop=0 makes the GIF loop.
    first.save(fp=fp_out, format='GIF', append_images=rest, optimize=False,
               save_all=True, loop=0, duration=duration)
|
|
|
|
|
|
def save_images(visuals, image_path):
    """Save every image in ``visuals`` into directory ``image_path``.

    Files are named ``'<index>_<label>.jpg'``, where ``index`` is the item's
    position in ``visuals``' iteration order. The directory (and any missing
    parents) is created when nothing exists at ``image_path``.

    Args:
        visuals: mapping of label -> image array, each passed to ``save_image``.
        image_path: output directory.
    """
    if not os.path.exists(image_path):
        # exist_ok tolerates another process creating the directory between
        # the check and this call.
        os.makedirs(image_path, exist_ok=True)

    for i, (label, img_numpy) in enumerate(visuals.items()):
        save_path = os.path.join(image_path, '%d_%s.jpg' % (i, label))
        save_image(img_numpy, save_path)
|
|
|
|
|
|
def save_images_test(visuals, image_path, from_name, to_name):
    """Save each image in ``visuals`` into ``image_path`` for one test pair.

    Files are named ``'<from_name>_<to_name>_<label>'``. When ``label`` has no
    file extension, ``'.jpg'`` is appended. The name is handed to
    ``save_image``, which calls PIL's ``Image.save`` without an explicit
    format, and PIL cannot pick a format for an extension-less name.

    Args:
        visuals: mapping of label -> image array.
        image_path: output directory; created (with parents) if nothing
            exists there.
        from_name: identifier embedded in every file name.
        to_name: identifier embedded in every file name.
    """
    if not os.path.exists(image_path):
        # exist_ok tolerates another process creating the directory between
        # the check and this call.
        os.makedirs(image_path, exist_ok=True)

    for label, img_numpy in visuals.items():
        img_name = "%s_%s_%s" % (from_name, to_name, label)
        if not os.path.splitext(str(label))[1]:
            img_name += '.jpg'
        save_image(img_numpy, os.path.join(image_path, img_name))
|
|
|
|
|
|
def compose_and_save_img(img_list, save_dir, img_name, col=4, row=1, img_size=(256, 200)):
    """Tile ``img_list`` into a grid with ``compose_image`` and save the result.

    Args:
        img_list: images in row-major order; ``compose_image`` indexes the
            first ``col * row`` of them.
        save_dir: output directory; created (with parents) if nothing exists
            there.
        img_name: file name of the composed image inside ``save_dir``.
        col: number of grid columns.
        row: number of grid rows.
        img_size: size of one grid cell, used as (width, height) by
            ``compose_image``.
    """
    compose_img = compose_image(img_list, col, row, img_size)
    if not os.path.exists(save_dir):
        # exist_ok tolerates another process creating the directory between
        # the check and this call.
        os.makedirs(save_dir, exist_ok=True)
    compose_img.save(os.path.join(save_dir, img_name))
|
|
|
|
|
|
def compose_image(img_list, col, row, img_size):
    """Paste images from ``img_list`` into a ``col`` x ``row`` RGB grid.

    Images are consumed in row-major order: cell (r, c) takes
    ``img_list[r * col + c]``. Every cell is ``img_size[0]`` pixels wide and
    ``img_size[1]`` pixels high. Each array is converted with
    ``Image.fromarray`` and pasted into its cell's box.

    Returns:
        The composed PIL image.
    """
    cell_w = img_size[0]
    cell_h = img_size[1]
    canvas = Image.new('RGB', (col * cell_w, row * cell_h))
    for r in range(row):
        top, bottom = r * cell_h, (r + 1) * cell_h
        for c in range(col):
            tile = Image.fromarray(img_list[r * col + c])
            left, right = c * cell_w, (c + 1) * cell_w
            canvas.paste(tile, (left, top, right, bottom))
    return canvas
|
|
|
|
|
|
def plot_loss_curve(losses, save_path, intervals=500):
    """Plot every loss series in ``losses`` on one figure and save it.

    Each series is averaged over consecutive windows of ``intervals`` points
    with ``list_cut_average`` before plotting, so the x axis counts windows.

    Args:
        losses: mapping of series name -> sequence of loss values.
        save_path: destination passed to ``plt.savefig``.
        intervals: averaging window length.
    """
    fig = plt.figure(figsize=(10, 5))
    plt.title("Loss During Training")
    for key, values in losses.items():
        plt.plot(list_cut_average(values, intervals), label=key)
    plt.xlabel("Iterations/" + str(intervals))
    plt.ylabel("Loss")
    plt.legend()
    plt.savefig(save_path)
    plt.show()
    # pyplot keeps a reference to every figure until it is closed, so
    # repeated calls (e.g. on a headless backend where show() does not close
    # anything) would accumulate open figures. In interactive mode the figure
    # is left open so a displayed window is not torn down right away.
    if not plt.isinteractive():
        plt.close(fig)
|
|
|
|
|
|
def list_cut_average(ll, intervals):
    """Average ``ll`` over consecutive, non-overlapping windows.

    Args:
        ll: sliceable sequence of numbers (e.g. a list or a 1-D array).
        intervals: window length; must be >= 1. The last window is shorter
            when ``len(ll)`` is not a multiple of ``intervals``.

    Returns:
        ``ll`` itself (not a copy) when ``intervals == 1``; otherwise a list
        holding ``np.mean`` of each window, which is ``[]`` for empty ``ll``.

    Raises:
        ValueError: if ``intervals`` is smaller than 1.
    """
    if intervals < 1:
        raise ValueError('intervals must be >= 1, got %r' % (intervals,))
    if intervals == 1:
        return ll
    # Slicing clamps the final window to len(ll), so no manual bound is needed.
    return [np.mean(ll[start:start + intervals]) for start in range(0, len(ll), intervals)]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|