|
import tensorflow as tf |
|
from tensorflow.contrib import slim |
|
import cv2 |
|
import os, random |
|
import numpy as np |
|
|
|
class ImageData:
    """Per-file image loading / preprocessing step for a TF input pipeline."""

    def __init__(self, load_size, channels, augment_flag):
        """
        Args:
            load_size: side length (pixels) every image is resized to.
            channels: number of colour channels requested from the decoder.
            augment_flag: if truthy, ``image_processing`` randomly applies
                ``augmentation`` to each image.
        """
        self.load_size = load_size
        self.channels = channels
        self.augment_flag = augment_flag

    def image_processing(self, filename):
        """Read, decode, resize and normalise a single image file.

        Args:
            filename: scalar string tensor holding the image path.

        Returns:
            A float32 image tensor resized to [load_size, load_size] with
            pixel values mapped from [0, 255] to roughly [-1, 1]. When
            ``augment_flag`` is set, about half of the images additionally go
            through ``augmentation``.
        """
        x = tf.read_file(filename)
        x_decode = tf.image.decode_jpeg(x, channels=self.channels)
        img = tf.image.resize_images(x_decode, [self.load_size, self.load_size])
        img = tf.cast(img, tf.float32) / 127.5 - 1

        if self.augment_flag:
            augment_size = self.load_size + (30 if self.load_size == 256 else 15)
            # The coin flip has to be a graph op. If this method is traced once
            # (e.g. passed to tf.data.Dataset.map), a Python-side
            # random.random() would be evaluated a single time and select the
            # same branch for every image.
            base = img
            coin = tf.random_uniform([], minval=0.0, maxval=1.0)
            img = tf.cond(coin > 0.5,
                          lambda: augmentation(base, augment_size),
                          lambda: base)

        return img
|
|
|
def load_test_data(image_path, size=256):
    """Load one image from disk as a normalised single-image batch.

    Args:
        image_path: path of the image file to read.
        size: side length the image is resized to (default 256).

    Returns:
        A float numpy array of shape (1, size, size, 3) in RGB channel order,
        with pixel values mapped from [0, 255] to [-1, 1].

    Raises:
        OSError: if OpenCV cannot read or decode ``image_path``.
    """
    img = cv2.imread(image_path, flags=cv2.IMREAD_COLOR)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising;
        # without this check the error surfaces as an opaque cv2.error in
        # cvtColor below.
        raise OSError("Could not read image: {}".format(image_path))

    # OpenCV loads BGR; the rest of the pipeline works in RGB.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, dsize=(size, size))

    # Add the leading batch axis, then normalise to [-1, 1].
    img = np.expand_dims(img, axis=0)
    img = img / 127.5 - 1

    return img
|
|
|
def augmentation(image, augment_size):
    """Random left/right flip, upscale, then random-crop back to the input size.

    Args:
        image: image tensor. Its dynamic shape, read before flipping, is used
            as the crop size, so the output has the same shape as the input.
        augment_size: side length the image is resized to before cropping.
            It must be >= the input height/width, because tf.random_crop
            requires the crop size not to exceed the value's shape.

    Returns:
        The augmented image tensor.
    """
    # Use distinct op-level seeds. Two TF random ops built with the same
    # (graph seed, op seed) pair draw identical underlying random streams,
    # which would tie the flip decision to the crop offset.
    flip_seed = random.randint(0, 2 ** 31 - 1)
    crop_seed = random.randint(0, 2 ** 31 - 1)

    ori_image_shape = tf.shape(image)
    image = tf.image.random_flip_left_right(image, seed=flip_seed)
    image = tf.image.resize_images(image, [augment_size, augment_size])
    image = tf.random_crop(image, ori_image_shape, seed=crop_seed)

    return image
|
|
|
def save_images(images, size, image_path):
    """Write a batch of normalised images to ``image_path`` as one tiled picture.

    ``images`` (assumed to be in [-1, 1]) is mapped back to [0, 255] with
    ``inverse_transform``. It is then tiled into a ``size`` grid and written
    by ``imsave``.

    Returns:
        Whatever ``imsave`` returns.
    """
    pixel_images = inverse_transform(images)
    return imsave(pixel_images, size, image_path)
|
|
|
def inverse_transform(images):
    """Map values from the [-1, 1] range back to [0, 255].

    This undoes the ``x / 127.5 - 1`` normalisation. It works on Python
    scalars and numpy arrays alike, and always yields floats.
    """
    halved = (images + 1.) / 2
    return halved * 255.0
|
|
|
|
|
def imsave(images, size, path):
    """Tile a batch of RGB images into one grid image and write it to ``path``.

    Args:
        images: batch of images with values expected in [0, 255]; tiled by
            ``merge``.
        size: (rows, cols) grid layout passed to ``merge``.
        path: output file path, written with cv2.imwrite.

    Returns:
        The result of ``cv2.imwrite``.
    """
    images = merge(images, size)

    # Saturate before the uint8 cast. Out-of-range floats would otherwise
    # wrap around (or be platform-dependent for negatives) instead of
    # clamping to 0 / 255.
    images = np.clip(images, 0, 255).astype('uint8')

    # OpenCV writes BGR, while the images are held in RGB.
    images = cv2.cvtColor(images, cv2.COLOR_RGB2BGR)

    return cv2.imwrite(path, images)
|
|
|
def merge(images, size):
    """Tile a batch of images into a single ``size[0] x size[1]`` grid.

    Tiles are placed row-major: image ``k`` lands in grid row
    ``k // size[1]`` and column ``k % size[1]``. Cells without an image stay
    zero.

    Args:
        images: batch indexed as ``images[k]``; tile height and width are
            read from ``images.shape[1]`` and ``images.shape[2]``.
        size: (rows, cols) of the grid.

    Returns:
        A float64 array of shape (h * size[0], w * size[1], 3).
    """
    tile_h, tile_w = images.shape[1], images.shape[2]
    n_cols = size[1]
    canvas = np.zeros((tile_h * size[0], tile_w * n_cols, 3))

    for k, tile in enumerate(images):
        row, col = divmod(k, n_cols)
        top = row * tile_h
        left = col * tile_w
        canvas[top:top + tile_h, left:left + tile_w, :] = tile

    return canvas
|
|
|
def show_all_variables():
    """Print a size summary of all trainable variables in the default TF graph."""
    slim.model_analyzer.analyze_vars(tf.trainable_variables(), print_info=True)
|
|
|
def check_folder(log_dir):
    """Ensure ``log_dir`` exists as a directory (creating parents) and return it.

    Args:
        log_dir: directory path to create if missing.

    Returns:
        ``log_dir`` unchanged, so the call can be used inline.

    Raises:
        OSError: if the directory cannot be created, e.g. because
            ``log_dir`` already exists as a regular file.
    """
    # exist_ok=True replaces the exists()/makedirs() pair, which raced when
    # several processes created the same directory concurrently.
    os.makedirs(log_dir, exist_ok=True)
    return log_dir
|
|
|
def str2bool(x):
    """Parse a flag string: True only for "true" (case-insensitive), else False."""
    # Exact comparison on purpose: ``x in ('true')`` tests substring
    # membership (the parentheses do not make a tuple), which made '', 't',
    # 'ru', ... parse as True.
    return x.lower() == 'true'
|
|