|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Cross Convolutional Model. |
|
|
|
https://arxiv.org/pdf/1607.02586v1.pdf |
|
""" |
|
import math |
|
import sys |
|
|
|
from six.moves import xrange |
|
import tensorflow as tf |
|
|
|
slim = tf.contrib.slim |
|
|
|
|
|
class CrossConvModel(object): |
|
|
|
def __init__(self, image_diff_list, params): |
|
"""Constructor. |
|
|
|
Args: |
|
image_diff_list: A list of (image, diff) tuples, with shape |
|
[batch_size, image_size, image_size, 3] and image_sizes as |
|
[32, 64, 128, 256]. |
|
params: Dict of parameters. |
|
""" |
|
self.images = [i for (i, _) in image_diff_list] |
|
|
|
self.diffs = [(d + params['scale']) / 2 for (i, d) in image_diff_list] |
|
self.params = params |
|
|
|
def Build(self): |
|
with tf.device('/gpu:0'): |
|
with slim.arg_scope([slim.conv2d], |
|
activation_fn=tf.nn.relu, |
|
normalizer_fn=slim.batch_norm, |
|
normalizer_params={'is_training': |
|
self.params['is_training']}): |
|
self._BuildMotionKernel() |
|
encoded_images = self._BuildImageEncoder() |
|
cross_conved_images = self._CrossConv(encoded_images) |
|
self._BuildImageDecoder(cross_conved_images) |
|
self._BuildLoss() |
|
|
|
image = self.images[1] |
|
diff = self.diffs[1] |
|
|
|
self.global_step = tf.Variable(0, name='global_step', trainable=False) |
|
|
|
if self.params['is_training']: |
|
self._BuildTrainOp() |
|
|
|
diff = diff * 2.0 - self.params['scale'] |
|
diff_output = self.diff_output * 2.0 - self.params['scale'] |
|
concat_image = tf.concat( |
|
axis=1, values=[image, image + diff_output, image + diff, diff_output]) |
|
tf.summary.image('origin_predict_expect_predictdiff', concat_image) |
|
self.summary_op = tf.summary.merge_all() |
|
return self.loss |
|
|
|
def _BuildTrainOp(self): |
|
lrn_rate = tf.maximum( |
|
0.01, |
|
tf.train.exponential_decay( |
|
self.params['learning_rate'], self.global_step, 10000, 0.5)) |
|
tf.summary.scalar('learning rate', lrn_rate) |
|
optimizer = tf.train.GradientDescentOptimizer(lrn_rate) |
|
self.train_op = slim.learning.create_train_op( |
|
self.loss, optimizer, global_step=self.global_step) |
|
|
|
def _BuildLoss(self): |
|
|
|
|
|
|
|
self.loss = 0 |
|
with tf.variable_scope('loss'): |
|
if self.params['l2_loss']: |
|
l2_loss = tf.reduce_mean(tf.square(self.diff_output - self.diffs[1])) |
|
tf.summary.scalar('l2_loss', l2_loss) |
|
self.loss += l2_loss |
|
if self.params['reconstr_loss']: |
|
reconstr_loss = (-tf.reduce_mean( |
|
self.diffs[1] * (1e-10 + self.diff_output) + |
|
(1-self.diffs[1]) * tf.log(1e-10 + 1 - self.diff_output))) |
|
reconstr_loss = tf.check_numerics(reconstr_loss, 'reconstr_loss') |
|
tf.summary.scalar('reconstr_loss', reconstr_loss) |
|
self.loss += reconstr_loss |
|
if self.params['kl_loss']: |
|
kl_loss = (0.5 * tf.reduce_mean( |
|
tf.square(self.z_mean) + tf.square(self.z_stddev) - |
|
2 * self.z_stddev_log - 1)) |
|
tf.summary.scalar('kl_loss', kl_loss) |
|
self.loss += kl_loss |
|
|
|
tf.summary.scalar('loss', self.loss) |
|
|
|
def _BuildMotionKernel(self): |
|
image = self.images[-2] |
|
diff = self.diffs[-2] |
|
shape = image.get_shape().as_list() |
|
assert shape[1] == shape[2] and shape[1] == 128 |
|
batch_size = shape[0] |
|
|
|
net = tf.concat(axis=3, values=[image, diff]) |
|
with tf.variable_scope('motion_encoder'): |
|
with slim.arg_scope([slim.conv2d], padding='VALID'): |
|
net = slim.conv2d(net, 96, [5, 5], stride=1) |
|
net = slim.max_pool2d(net, [2, 2]) |
|
net = slim.conv2d(net, 96, [5, 5], stride=1) |
|
net = slim.max_pool2d(net, [2, 2]) |
|
net = slim.conv2d(net, 128, [5, 5], stride=1) |
|
net = slim.conv2d(net, 128, [5, 5], stride=1) |
|
net = slim.max_pool2d(net, [2, 2]) |
|
net = slim.conv2d(net, 256, [4, 4], stride=1) |
|
net = slim.conv2d(net, 256, [3, 3], stride=1) |
|
|
|
z = tf.reshape(net, shape=[batch_size, -1]) |
|
self.z_mean, self.z_stddev_log = tf.split( |
|
axis=1, num_or_size_splits=2, value=z) |
|
self.z_stddev = tf.exp(self.z_stddev_log) |
|
|
|
epsilon = tf.random_normal( |
|
self.z_mean.get_shape().as_list(), 0, 1, dtype=tf.float32) |
|
kernel = self.z_mean + tf.multiply(self.z_stddev, epsilon) |
|
|
|
width = int(math.sqrt(kernel.get_shape().as_list()[1] // 128)) |
|
kernel = tf.reshape(kernel, [batch_size, width, width, 128]) |
|
with tf.variable_scope('kernel_decoder'): |
|
with slim.arg_scope([slim.conv2d], padding='SAME'): |
|
kernel = slim.conv2d(kernel, 128, [5, 5], stride=1) |
|
self.kernel = slim.conv2d(kernel, 128, [5, 5], stride=1) |
|
|
|
sys.stderr.write('kernel shape: %s\n' % kernel.get_shape()) |
|
|
|
def _BuildImageEncoder(self): |
|
feature_maps = [] |
|
for (i, image) in enumerate(self.images): |
|
with tf.variable_scope('image_encoder_%d' % i): |
|
with slim.arg_scope([slim.conv2d, slim.max_pool2d], padding='SAME'): |
|
net = slim.conv2d(image, 64, [5, 5], stride=1) |
|
net = slim.conv2d(net, 64, [5, 5], stride=1) |
|
net = slim.max_pool2d(net, [5, 5]) |
|
net = slim.conv2d(net, 64, [5, 5], stride=1) |
|
net = slim.conv2d(net, 32, [5, 5], stride=1) |
|
net = slim.max_pool2d(net, [2, 2]) |
|
sys.stderr.write('image_conv shape: %s\n' % net.get_shape()) |
|
feature_maps.append(net) |
|
return feature_maps |
|
|
|
def _CrossConvHelper(self, encoded_image, kernel): |
|
"""Cross Convolution. |
|
|
|
The encoded image and kernel are of the same shape. Namely |
|
[batch_size, image_size, image_size, channels]. They are split |
|
into [image_size, image_size] image squares [kernel_size, kernel_size] |
|
kernel squares. kernel squares are used to convolute image squares. |
|
""" |
|
images = tf.expand_dims(encoded_image, 0) |
|
kernels = tf.expand_dims(kernel, 3) |
|
return tf.nn.depthwise_conv2d(images, kernels, [1, 1, 1, 1], 'SAME') |
|
|
|
def _CrossConv(self, encoded_images): |
|
"""Apply the motion kernel on the encoded_images.""" |
|
cross_conved_images = [] |
|
kernels = tf.split(axis=3, num_or_size_splits=4, value=self.kernel) |
|
for (i, encoded_image) in enumerate(encoded_images): |
|
with tf.variable_scope('cross_conv_%d' % i): |
|
kernel = kernels[i] |
|
|
|
encoded_image = tf.unstack(encoded_image, axis=0) |
|
kernel = tf.unstack(kernel, axis=0) |
|
assert len(encoded_image) == len(kernel) |
|
assert len(encoded_image) == self.params['batch_size'] |
|
conved_image = [] |
|
for j in xrange(len(encoded_image)): |
|
conved_image.append(self._CrossConvHelper( |
|
encoded_image[j], kernel[j])) |
|
cross_conved_images.append(tf.concat(axis=0, values=conved_image)) |
|
sys.stderr.write('cross_conved shape: %s\n' % |
|
cross_conved_images[-1].get_shape()) |
|
return cross_conved_images |
|
|
|
def _Deconv(self, net, out_filters, kernel_size, stride): |
|
shape = net.get_shape().as_list() |
|
in_filters = shape[3] |
|
kernel_shape = [kernel_size, kernel_size, out_filters, in_filters] |
|
|
|
weights = tf.get_variable( |
|
name='weights', |
|
shape=kernel_shape, |
|
dtype=tf.float32, |
|
initializer=tf.truncated_normal_initializer(stddev=0.01)) |
|
|
|
|
|
out_height = shape[1] * stride |
|
out_width = shape[2] * stride |
|
batch_size = shape[0] |
|
|
|
output_shape = [batch_size, out_height, out_width, out_filters] |
|
net = tf.nn.conv2d_transpose(net, weights, output_shape, |
|
[1, stride, stride, 1], padding='SAME') |
|
slim.batch_norm(net) |
|
return net |
|
|
|
def _BuildImageDecoder(self, cross_conved_images): |
|
"""Decode the cross_conved feature maps into the predicted images.""" |
|
nets = [] |
|
for i, cross_conved_image in enumerate(cross_conved_images): |
|
with tf.variable_scope('image_decoder_%d' % i): |
|
stride = 64 / cross_conved_image.get_shape().as_list()[1] |
|
|
|
nets.append(self._Deconv( |
|
cross_conved_image, 64, kernel_size=3, stride=stride)) |
|
|
|
net = tf.concat(axis=3, values=nets) |
|
net = slim.conv2d(net, 128, [9, 9], padding='SAME', stride=1) |
|
net = slim.conv2d(net, 128, [1, 1], padding='SAME', stride=1) |
|
net = slim.conv2d(net, 3, [1, 1], padding='SAME', stride=1) |
|
self.diff_output = net |
|
sys.stderr.write('diff_output shape: %s\n' % self.diff_output.get_shape()) |
|
|