Spaces:
Runtime error
Runtime error
import tensorflow as tf | |
from tensorflow.keras import losses | |
class SpatialConsistencyLoss(losses.Loss):
    """Penalise changes in local contrast between two image batches.

    Steps:
    1. Collapse both inputs to one channel by averaging over axis 3.
    2. Average-pool over non-overlapping 4x4 regions.
    3. For every pooled region, take the difference to its left, right,
       upper and lower neighbour. Zero padding applies at the borders.

    The loss is the sum over the four directions of the squared discrepancy
    between those neighbour differences in ``y_true`` and ``y_pred``.

    By default the result is unreduced (``reduction="none"``): one value per
    pooled region, with shape ``(batch, H // 4, W // 4, 1)`` for NHWC inputs.
    """

    def __init__(self, **kwargs):
        """Create the loss.

        Args:
            **kwargs: Forwarded to ``keras.losses.Loss`` (e.g. ``name``).
                ``reduction`` defaults to ``"none"`` so the per-region loss
                map is returned for the caller to reduce.
        """
        # Forward kwargs (previously dropped) but keep the unreduced default.
        kwargs.setdefault("reduction", "none")
        super().__init__(**kwargs)
        # Each stencil is written row-major as (height, width). Convolving
        # with it yields ``centre - neighbour`` in the named direction.
        self.left_kernel = self._make_kernel([[0, 0, 0], [-1, 1, 0], [0, 0, 0]])
        self.right_kernel = self._make_kernel([[0, 0, 0], [0, 1, -1], [0, 0, 0]])
        self.up_kernel = self._make_kernel([[0, -1, 0], [0, 1, 0], [0, 0, 0]])
        self.down_kernel = self._make_kernel([[0, 0, 0], [0, 1, 0], [0, -1, 0]])

    @staticmethod
    def _make_kernel(stencil):
        """Turn a 3x3 nested list into a (3, 3, 1, 1) float32 conv2d filter."""
        # conv2d filters are (filter_height, filter_width, in_channels,
        # out_channels). Reshape explicitly so the stencil spans both
        # spatial axes with a single input and output channel.
        return tf.reshape(tf.constant(stencil, dtype=tf.float32), (3, 3, 1, 1))

    @staticmethod
    def _pooled_mean(images):
        """Average over axis 3 (kept as size 1), then 4x4 average-pool."""
        mean = tf.reduce_mean(images, axis=3, keepdims=True)
        return tf.nn.avg_pool2d(mean, ksize=4, strides=4, padding="VALID")

    @staticmethod
    def _neighbour_diff(pooled, kernel):
        """Apply a difference stencil, keeping the spatial size (SAME padding)."""
        return tf.nn.conv2d(pooled, kernel, strides=[1, 1, 1, 1], padding="SAME")

    def call(self, y_true, y_pred):
        """Return the per-region spatial consistency loss.

        Args:
            y_true: Reference image batch. Assumed NHWC, the default layout
                of ``tf.nn.avg_pool2d`` / ``tf.nn.conv2d`` (TODO confirm with
                callers).
            y_pred: Image batch to compare; expected to match ``y_true``'s
                shape.

        Returns:
            Tensor of shape ``(batch, H // 4, W // 4, 1)``: the sum over the
            four directions of squared neighbour-difference discrepancies.
        """
        original_pool = self._pooled_mean(y_true)
        enhanced_pool = self._pooled_mean(y_pred)
        kernels = (
            self.left_kernel,
            self.right_kernel,
            self.up_kernel,
            self.down_kernel,
        )
        return tf.add_n(
            [
                tf.square(
                    self._neighbour_diff(original_pool, kernel)
                    - self._neighbour_diff(enhanced_pool, kernel)
                )
                for kernel in kernels
            ]
        )