import os
import sys
import glob
import datetime
import pickle
import re
import numpy as np
from collections import OrderedDict
import scipy.ndimage
import PIL.Image

import config
import dataset
import legacy

def load_pkl(filename):
    with open(filename, 'rb') as file:
        return legacy.LegacyUnpickler(file, encoding='latin1').load()

def save_pkl(obj, filename):
    with open(filename, 'wb') as file:
        pickle.dump(obj, file, protocol=pickle.HIGHEST_PROTOCOL)

def adjust_dynamic_range(data, drange_in, drange_out):
    if drange_in != drange_out:
        scale = (np.float32(drange_out[1]) - np.float32(drange_out[0])) / (np.float32(drange_in[1]) - np.float32(drange_in[0]))
        bias = np.float32(drange_out[0]) - np.float32(drange_in[0]) * scale
        data = data * scale + bias
    return data
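
# Illustrative example (a sketch, not used elsewhere in this module): remap
# uint8 pixels in [0, 255] to the [-1, 1] range, so 0 -> -1.0 and 255 -> +1.0.
# The file name 'photo.png' is hypothetical.
#
#   raw = np.asarray(PIL.Image.open('photo.png'), dtype=np.float32)
#   scaled = adjust_dynamic_range(raw, drange_in=[0, 255], drange_out=[-1, 1])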

def create_image_grid(images, grid_size=None):
    assert images.ndim == 3 or images.ndim == 4
    num, img_w, img_h = images.shape[0], images.shape[-1], images.shape[-2]

    if grid_size is not None:
        grid_w, grid_h = tuple(grid_size)
    else:
        grid_w = max(int(np.ceil(np.sqrt(num))), 1)
        grid_h = max((num - 1) // grid_w + 1, 1)

    grid = np.zeros(list(images.shape[1:-2]) + [grid_h * img_h, grid_w * img_w], dtype=images.dtype)
    for idx in range(num):
        x = (idx % grid_w) * img_w
        y = (idx // grid_w) * img_h
        grid[..., y : y + img_h, x : x + img_w] = images[idx]
    return grid
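
# Illustrative example (a sketch; the random NCHW batch is a placeholder):
# tile 16 images into a square 4x4 grid, or force a specific layout.
#
#   batch = np.random.rand(16, 3, 64, 64).astype(np.float32)
#   grid = create_image_grid(batch)                    # shape (3, 256, 256)
#   wide = create_image_grid(batch, grid_size=(8, 2))  # 8 columns, 2 rows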

def convert_to_pil_image(image, drange=[0,1]):
    assert image.ndim == 2 or image.ndim == 3
    if image.ndim == 3:
        if image.shape[0] == 1:
            image = image[0]  # grayscale CHW => HW
        else:
            image = image.transpose(1, 2, 0)  # CHW => HWC

    image = adjust_dynamic_range(image, drange, [0,255])
    image = np.rint(image).clip(0, 255).astype(np.uint8)
    fmt = 'RGB' if image.ndim == 3 else 'L'  # avoid shadowing the builtin 'format'
    return PIL.Image.fromarray(image, fmt)

def save_image(image, filename, drange=[0,1], quality=95):
    img = convert_to_pil_image(image, drange)
    if filename.lower().endswith(('.jpg', '.jpeg')):  # check the extension, not a substring of the whole path
        img.save(filename, "JPEG", quality=quality, optimize=True)
    else:
        img.save(filename)

def save_image_grid(images, filename, drange=[0,1], grid_size=None):
    convert_to_pil_image(create_image_grid(images, grid_size), drange).save(filename)
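
# Illustrative example (a sketch): save a batch of generator outputs, assumed
# to be NCHW floats in [-1, 1], as a single PNG contact sheet plus one JPEG.
#
#   fakes = np.random.randn(16, 3, 64, 64).astype(np.float32)  # placeholder for G output
#   save_image_grid(fakes, 'fakes.png', drange=[-1, 1], grid_size=(4, 4))
#   save_image(fakes[0], 'fake0.jpg', drange=[-1, 1], quality=90)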

class OutputLogger(object):
    def __init__(self):
        self.file = None
        self.buffer = ''

    def set_log_file(self, filename, mode='wt'):
        assert self.file is None
        self.file = open(filename, mode)
        if self.buffer is not None:
            self.file.write(self.buffer)
            self.buffer = None

    def write(self, data):
        if self.file is not None:
            self.file.write(data)
        if self.buffer is not None:
            self.buffer += data

    def flush(self):
        if self.file is not None:
            self.file.flush()

class TeeOutputStream(object):
    def __init__(self, child_streams, autoflush=False):
        self.child_streams = child_streams
        self.autoflush = autoflush

    def write(self, data):
        for stream in self.child_streams:
            stream.write(data)
        if self.autoflush:
            self.flush()

    def flush(self):
        for stream in self.child_streams:
            stream.flush()

output_logger = None

def init_output_logging():
    global output_logger
    if output_logger is None:
        output_logger = OutputLogger()
        sys.stdout = TeeOutputStream([sys.stdout, output_logger], autoflush=True)
        sys.stderr = TeeOutputStream([sys.stderr, output_logger], autoflush=True)

def set_output_log_file(filename, mode='wt'):
    if output_logger is not None:
        output_logger.set_log_file(filename, mode)
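
# Illustrative usage (a sketch): install the tee early so that output produced
# before the log file name is known gets buffered, then replayed into the file
# once set_output_log_file() is called.
#
#   init_output_logging()
#   print('buffered until a log file is set')
#   set_output_log_file('log.txt')   # replays the buffer, then logs live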

def create_result_subdir(result_dir, desc):
    # Select run ID and create subdir.
    while True:
        run_id = 0
        for fname in glob.glob(os.path.join(result_dir, '*')):
            try:
                fbase = os.path.basename(fname)
                ford = int(fbase[:fbase.find('-')])
                run_id = max(run_id, ford + 1)
            except ValueError:
                pass

        result_subdir = os.path.join(result_dir, '%03d-%s' % (run_id, desc))
        try:
            os.makedirs(result_subdir)
            break
        except OSError:
            if os.path.isdir(result_subdir):
                continue  # lost a race against another process -- retry with the next ID
            raise

    print("Saving results to", result_subdir)
    set_output_log_file(os.path.join(result_subdir, 'log.txt'))

    # Export a snapshot of the config module for reproducibility.
    try:
        with open(os.path.join(result_subdir, 'config.txt'), 'wt') as fout:
            for k, v in sorted(config.__dict__.items()):
                if not k.startswith('_'):
                    fout.write("%s = %s\n" % (k, str(v)))
    except OSError:  # narrowed from a bare except; only I/O errors are expected here
        pass

    return result_subdir

def format_time(seconds):
    s = int(np.rint(seconds))
    if s < 60:         return '%ds' % s
    elif s < 60*60:    return '%dm %02ds' % (s // 60, s % 60)
    elif s < 24*60*60: return '%dh %02dm %02ds' % (s // (60*60), (s // 60) % 60, s % 60)
    else:              return '%dd %02dh %02dm' % (s // (24*60*60), (s // (60*60)) % 24, (s // 60) % 60)
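
# Formatting thresholds, for reference:
#   format_time(59)    => '59s'
#   format_time(61)    => '1m 01s'
#   format_time(3723)  => '1h 02m 03s'
#   format_time(90061) => '1d 01h 01m'   (seconds are dropped beyond one day)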

def locate_result_subdir(run_id_or_result_subdir):
    # An existing directory path is accepted as-is.
    if isinstance(run_id_or_result_subdir, str) and os.path.isdir(run_id_or_result_subdir):
        return run_id_or_result_subdir

    searchdirs = ['', 'results', 'networks']

    for searchdir in searchdirs:
        dir = config.result_dir if searchdir == '' else os.path.join(config.result_dir, searchdir)
        dir = os.path.join(dir, str(run_id_or_result_subdir))
        if os.path.isdir(dir):
            return dir
        prefix = '%03d' % run_id_or_result_subdir if isinstance(run_id_or_result_subdir, int) else str(run_id_or_result_subdir)
        dirs = sorted(glob.glob(os.path.join(config.result_dir, searchdir, prefix + '-*')))
        dirs = [dir for dir in dirs if os.path.isdir(dir)]
        if len(dirs) == 1:
            return dirs[0]
    raise IOError('Cannot locate result subdir for run', run_id_or_result_subdir)

def list_network_pkls(run_id_or_result_subdir, include_final=True):
    result_subdir = locate_result_subdir(run_id_or_result_subdir)
    pkls = sorted(glob.glob(os.path.join(result_subdir, 'network-*.pkl')))
    # 'network-final.pkl' sorts before the 'network-snapshot-*.pkl' files; move
    # it to the end of the list (or drop it) so snapshots stay in training order.
    if len(pkls) >= 1 and os.path.basename(pkls[0]) == 'network-final.pkl':
        if include_final:
            pkls.append(pkls[0])
        del pkls[0]
    return pkls

def locate_network_pkl(run_id_or_result_subdir_or_network_pkl, snapshot=None):
    if isinstance(run_id_or_result_subdir_or_network_pkl, str) and os.path.isfile(run_id_or_result_subdir_or_network_pkl):
        return run_id_or_result_subdir_or_network_pkl

    pkls = list_network_pkls(run_id_or_result_subdir_or_network_pkl)
    if len(pkls) >= 1 and snapshot is None:
        return pkls[-1]  # latest snapshot, or the final network if present
    for pkl in pkls:
        try:
            name = os.path.splitext(os.path.basename(pkl))[0]
            number = int(name.split('-')[-1])
            if number == snapshot:
                return pkl
        except (ValueError, IndexError):
            pass
    raise IOError('Cannot locate network pkl for snapshot', snapshot)

def get_id_string_for_network_pkl(network_pkl):
    # Last two path components without the extension,
    # e.g. 'results/000-desc/network-final.pkl' => '000-desc-network-final'.
    p = network_pkl.replace('.pkl', '').replace('\\', '/').split('/')
    return '-'.join(p[max(len(p) - 2, 0):])

def load_network_pkl(run_id_or_result_subdir_or_network_pkl, snapshot=None):
    return load_pkl(locate_network_pkl(run_id_or_result_subdir_or_network_pkl, snapshot))
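
# Illustrative usage (a sketch; run ID 23 and snapshot 8000 are hypothetical):
#
#   pkls = list_network_pkls(23)                    # snapshots in training order
#   G, D, Gs = load_network_pkl(23)                 # latest snapshot or final network
#   G, D, Gs = load_network_pkl(23, snapshot=8000)  # 'network-snapshot-008000.pkl'
#
# Unpacking into (G, D, Gs) assumes the pickle stores the generator,
# discriminator, and averaged generator, as the training scripts save them.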

def random_latents(num_latents, G, random_state=None):
    rnd = np.random if random_state is None else random_state
    return rnd.randn(num_latents, *G.input_shape[1:]).astype(np.float32)
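
# Illustrative usage (a sketch): draw a reproducible batch of latents whose
# shape matches the generator's input.
#
#   rs = np.random.RandomState(123)
#   latents = random_latents(8, G, random_state=rs)  # shape (8,) + G.input_shape[1:]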

def load_dataset_for_previous_run(run_id, **kwargs):
    result_subdir = locate_result_subdir(run_id)

    # Parse the 'dataset' and 'train' dicts out of the run's config.txt.
    parsed_cfg = dict()
    with open(os.path.join(result_subdir, 'config.txt'), 'rt') as f:
        for line in f:
            if line.startswith('dataset =') or line.startswith('train ='):
                exec(line, parsed_cfg, parsed_cfg)
    dataset_cfg = parsed_cfg.get('dataset', dict())
    train_cfg = parsed_cfg.get('train', dict())
    mirror_augment = train_cfg.get('mirror_augment', False)

    # Handle legacy options.
    if 'h5_path' in dataset_cfg:
        dataset_cfg['tfrecord_dir'] = dataset_cfg.pop('h5_path').replace('.h5', '')
    if 'mirror_augment' in dataset_cfg:
        mirror_augment = dataset_cfg.pop('mirror_augment')
    if 'max_labels' in dataset_cfg:
        v = dataset_cfg.pop('max_labels')
        if v is None: v = 0
        if v == 'all': v = 'full'
        dataset_cfg['max_label_size'] = v
    if 'max_images' in dataset_cfg:
        dataset_cfg.pop('max_images')

    # Handle legacy dataset names.
    v = dataset_cfg['tfrecord_dir']
    v = v.replace('-32x32', '').replace('-32', '')
    v = v.replace('-128x128', '').replace('-128', '')
    v = v.replace('-256x256', '').replace('-256', '')
    v = v.replace('-1024x1024', '').replace('-1024', '')
    v = v.replace('celeba-hq', 'celebahq')
    v = v.replace('cifar-10', 'cifar10')
    v = v.replace('cifar-100', 'cifar100')
    v = v.replace('mnist-rgb', 'mnistrgb')
    v = re.sub('lsun-100k-([^-]*)', 'lsun-\\1-100k', v)
    v = re.sub('lsun-full-([^-]*)', 'lsun-\\1-full', v)
    dataset_cfg['tfrecord_dir'] = v

    # Load the dataset, letting explicit keyword arguments override the config.
    dataset_cfg.update(kwargs)
    dataset_obj = dataset.load_dataset(data_dir=config.data_dir, **dataset_cfg)
    return dataset_obj, mirror_augment
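
# Illustrative usage (a sketch; run ID 23 is hypothetical, and the keyword
# override is only an assumption about what dataset.load_dataset() accepts):
#
#   dataset_obj, mirror_augment = load_dataset_for_previous_run(23)
#   dataset_obj, mirror_augment = load_dataset_for_previous_run(23, max_label_size='full')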

def apply_mirror_augment(minibatch):
    # Flip a random half of the NCHW minibatch along the width axis.
    mask = np.random.rand(minibatch.shape[0]) < 0.5
    minibatch = np.array(minibatch)  # copy so the caller's array is untouched
    minibatch[mask] = minibatch[mask, :, :, ::-1]
    return minibatch
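
# Quick sanity check (a sketch): each image is flipped with probability 0.5,
# and the original batch is left unmodified.
#
#   batch = np.arange(8, dtype=np.float32).reshape(2, 1, 2, 2)
#   augmented = apply_mirror_augment(batch)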

_text_label_cache = OrderedDict()

def draw_text_label(img, text, x, y, alignx=0.5, aligny=0.5, color=255, opacity=1.0, glow_opacity=1.0, **kwargs):
    color = np.array(color).flatten().astype(np.float32)
    assert img.ndim == 3 and (img.shape[2] == color.size or color.size == 1)  # parenthesized: an HWC image is required even for scalar colors
    alpha, glow = setup_text_label(text, **kwargs)
    xx, yy = int(np.rint(x - alpha.shape[1] * alignx)), int(np.rint(y - alpha.shape[0] * aligny))
    xb, yb = max(-xx, 0), max(-yy, 0)
    xe, ye = min(alpha.shape[1], img.shape[1] - xx), min(alpha.shape[0], img.shape[0] - yy)
    img = np.array(img)  # copy so the caller's array is untouched
    region = img[yy+yb : yy+ye, xx+xb : xx+xe, :]  # 'region' instead of shadowing the builtin 'slice'
    region[:] = region * (1.0 - (1.0 - (1.0 - alpha[yb:ye, xb:xe]) * (1.0 - glow[yb:ye, xb:xe] * glow_opacity)) * opacity)[:, :, np.newaxis]
    region[:] = region + alpha[yb:ye, xb:xe, np.newaxis] * (color * opacity)[np.newaxis, np.newaxis, :]
    return img

def setup_text_label(text, font='Calibri', fontsize=32, padding=6, glow_size=2.0, glow_coef=3.0, glow_exp=2.0, cache_size=100):
    # Lookup from cache; re-insert the hit to keep the OrderedDict in LRU order.
    key = (text, font, fontsize, padding, glow_size, glow_coef, glow_exp)
    if key in _text_label_cache:
        value = _text_label_cache[key]
        del _text_label_cache[key]
        _text_label_cache[key] = value
        return value

    # Evict the least recently used entries.
    while len(_text_label_cache) >= cache_size:
        _text_label_cache.popitem(last=False)

    # Render the text via moviepy (imported lazily) and precompute a glow mask.
    import moviepy.editor
    alpha = moviepy.editor.TextClip(text, font=font, fontsize=fontsize).mask.make_frame(0)
    alpha = np.pad(alpha, padding, mode='constant', constant_values=0.0)
    glow = scipy.ndimage.gaussian_filter(alpha, glow_size)
    glow = 1.0 - np.maximum(1.0 - glow * glow_coef, 0.0) ** glow_exp

    # Add to cache.
    value = (alpha, glow)
    _text_label_cache[key] = value
    return value
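
# Illustrative usage (a sketch; assumes an HWC uint8 canvas and that the
# default 'Calibri' font is available to moviepy's TextClip backend):
#
#   canvas = np.zeros((256, 256, 3), dtype=np.uint8)
#   canvas = draw_text_label(canvas, 'fake', x=128, y=240, color=(255, 255, 255))
#   PIL.Image.fromarray(canvas, 'RGB').save('labeled.png')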